From f639c4c3d14f19128c244610e93c0589d174aa0a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 10:41:44 +0100 Subject: [PATCH 001/109] Add schema for reconnection support --- .../20221109000000_test_schema.sql | 56 +++++++++++++++++-- .../20221111092550_reconnection_support.sql | 47 ++++++++++++++++ 2 files changed, 97 insertions(+), 6 deletions(-) create mode 100644 crates/collab/migrations/20221111092550_reconnection_support.sql diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 63d2661de5d5d2262b371de651b434f6fe1a6c38..731910027e1ef362e0cbe36ceec7ca0e7f5c0f88 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -1,4 +1,4 @@ -CREATE TABLE IF NOT EXISTS "users" ( +CREATE TABLE "users" ( "id" INTEGER PRIMARY KEY, "github_login" VARCHAR, "admin" BOOLEAN, @@ -16,14 +16,14 @@ CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code"); CREATE INDEX "index_users_on_email_address" ON "users" ("email_address"); CREATE INDEX "index_users_on_github_user_id" ON "users" ("github_user_id"); -CREATE TABLE IF NOT EXISTS "access_tokens" ( +CREATE TABLE "access_tokens" ( "id" INTEGER PRIMARY KEY, "user_id" INTEGER REFERENCES users (id), "hash" VARCHAR(128) ); CREATE INDEX "index_access_tokens_user_id" ON "access_tokens" ("user_id"); -CREATE TABLE IF NOT EXISTS "contacts" ( +CREATE TABLE "contacts" ( "id" INTEGER PRIMARY KEY, "user_id_a" INTEGER REFERENCES users (id) NOT NULL, "user_id_b" INTEGER REFERENCES users (id) NOT NULL, @@ -34,8 +34,52 @@ CREATE TABLE IF NOT EXISTS "contacts" ( CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b"); CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); -CREATE TABLE IF NOT EXISTS "projects" ( +CREATE TABLE "rooms" ( "id" INTEGER PRIMARY KEY, - "host_user_id" INTEGER 
REFERENCES users (id) NOT NULL, - "unregistered" BOOLEAN NOT NULL DEFAULT false + "version" INTEGER NOT NULL, + "live_kit_room" VARCHAR NOT NULL ); + +CREATE TABLE "projects" ( + "id" INTEGER PRIMARY KEY, + "room_id" INTEGER REFERENCES rooms (id), + "host_user_id" INTEGER REFERENCES users (id) NOT NULL +); + +CREATE TABLE "project_collaborators" ( + "id" INTEGER PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "connection_id" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); + +CREATE TABLE "worktrees" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "root_name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); + +CREATE TABLE "room_participants" ( + "id" INTEGER PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "user_id" INTEGER NOT NULL REFERENCES users (id), + "connection_id" INTEGER, + "location_kind" INTEGER, + "location_project_id" INTEGER REFERENCES projects (id) +); +CREATE UNIQUE INDEX "index_room_participants_on_user_id_and_room_id" ON "room_participants" ("user_id", "room_id"); + +CREATE TABLE "calls" ( + "id" INTEGER PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "called_user_id" INTEGER NOT NULL REFERENCES users (id), + "answering_connection_id" INTEGER, + "initial_project_id" INTEGER REFERENCES projects (id) +); +CREATE UNIQUE INDEX "index_calls_on_calling_user_id" ON "calls" ("calling_user_id"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql new file mode 100644 index 0000000000000000000000000000000000000000..9474beff4296215588344b22368f7aecdd36006a --- 
/dev/null +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -0,0 +1,47 @@ +CREATE TABLE IF NOT EXISTS "rooms" ( + "id" SERIAL PRIMARY KEY, + "version" INTEGER NOT NULL, + "live_kit_room" VARCHAR NOT NULL +); + +ALTER TABLE "projects" + ADD "room_id" INTEGER REFERENCES rooms (id), + DROP COLUMN "unregistered"; + +CREATE TABLE "project_collaborators" ( + "id" SERIAL PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "connection_id" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); + +CREATE TABLE IF NOT EXISTS "worktrees" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "root_name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); + +CREATE TABLE IF NOT EXISTS "room_participants" ( + "id" SERIAL PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "user_id" INTEGER NOT NULL REFERENCES users (id), + "connection_id" INTEGER, + "location_kind" INTEGER, + "location_project_id" INTEGER REFERENCES projects (id) +); +CREATE UNIQUE INDEX "index_room_participants_on_user_id_and_room_id" ON "room_participants" ("user_id", "room_id"); + +CREATE TABLE IF NOT EXISTS "calls" ( + "id" SERIAL PRIMARY KEY, + "room_id" INTEGER NOT NULL REFERENCES rooms (id), + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "called_user_id" INTEGER NOT NULL REFERENCES users (id), + "answering_connection_id" INTEGER, + "initial_project_id" INTEGER REFERENCES projects (id) +); +CREATE UNIQUE INDEX "index_calls_on_calling_user_id" ON "calls" ("calling_user_id"); From 28aa1567ce8d814a9a3ffbcd1b566a1b343907d4 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 7 Nov 2022 15:40:02 +0100 Subject: [PATCH 002/109] Include `sender_user_id` when handling a 
server message/request --- crates/collab/src/rpc.rs | 465 +++++++++++++++++++++++---------------- 1 file changed, 276 insertions(+), 189 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7bc2b43b9b4c24cdf991e88c90a4e966927a8cfd..757c765838551666613445121bfc4625ad89a2e6 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -68,8 +68,15 @@ lazy_static! { .unwrap(); } -type MessageHandler = - Box, Box) -> BoxFuture<'static, ()>>; +type MessageHandler = Box< + dyn Send + Sync + Fn(Arc, UserId, Box) -> BoxFuture<'static, ()>, +>; + +struct Message { + sender_user_id: UserId, + sender_connection_id: ConnectionId, + payload: T, +} struct Response { server: Arc, @@ -193,15 +200,15 @@ impl Server { Arc::new(server) } - fn add_message_handler(&mut self, handler: F) -> &mut Self + fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope) -> Fut, + F: 'static + Send + Sync + Fn(Arc, UserId, TypedEnvelope) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, envelope| { + Box::new(move |server, sender_user_id, envelope| { let envelope = envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -213,7 +220,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, *envelope); + let future = (handler)(server, sender_user_id, *envelope); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -229,26 +236,50 @@ impl Server { self } + fn add_message_handler(&mut self, handler: F) -> &mut Self + where + F: 'static + Send + Sync + Fn(Arc, Message) -> Fut, + Fut: 'static + Send + Future>, + M: EnvelopedMessage, + { + self.add_handler(move |server, sender_user_id, envelope| { + handler( + server, + Message { + sender_user_id, + sender_connection_id: envelope.sender_id, + payload: envelope.payload, + }, 
+ ) + }); + self + } + /// Handle a request while holding a lock to the store. This is useful when we're registering /// a connection but we want to respond on the connection before anybody else can send on it. fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Response) -> Fut, + F: 'static + Send + Sync + Fn(Arc, Message, Response) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_message_handler(move |server, envelope| { + self.add_handler(move |server, sender_user_id, envelope| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { + let request = Message { + sender_user_id, + sender_connection_id: envelope.sender_id, + payload: envelope.payload, + }; let responded = Arc::new(AtomicBool::default()); let response = Response { server: server.clone(), responded: responded.clone(), - receipt: envelope.receipt(), + receipt, }; - match (handler)(server.clone(), envelope, response).await { + match (handler)(server.clone(), request, response).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -361,7 +392,7 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(this.clone(), message); + let handle_message = (handler)(this.clone(), user_id, message); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -516,7 +547,7 @@ impl Server { async fn ping( self: Arc, - _: TypedEnvelope, + _: Message, response: Response, ) -> Result<()> { response.send(proto::Ack {})?; @@ -525,15 +556,13 @@ impl Server { async fn create_room( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id; let room; { let mut store = self.store().await; - user_id = store.user_id_for_connection(request.sender_id)?; - room = 
store.create_room(request.sender_id)?.clone(); + room = store.create_room(request.sender_connection_id)?.clone(); } let live_kit_connection_info = @@ -544,7 +573,10 @@ impl Server { .trace_err() { if let Some(token) = live_kit - .room_token(&room.live_kit_room, &request.sender_id.to_string()) + .room_token( + &room.live_kit_room, + &request.sender_connection_id.to_string(), + ) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -565,21 +597,19 @@ impl Server { room: Some(room), live_kit_connection_info, })?; - self.update_user_contacts(user_id).await?; + self.update_user_contacts(request.sender_user_id).await?; Ok(()) } async fn join_room( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id; { let mut store = self.store().await; - user_id = store.user_id_for_connection(request.sender_id)?; let (room, recipient_connection_ids) = - store.join_room(request.payload.id, request.sender_id)?; + store.join_room(request.payload.id, request.sender_connection_id)?; for recipient_id in recipient_connection_ids { self.peer .send(recipient_id, proto::CallCanceled {}) @@ -589,7 +619,10 @@ impl Server { let live_kit_connection_info = if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { if let Some(token) = live_kit - .room_token(&room.live_kit_room, &request.sender_id.to_string()) + .room_token( + &room.live_kit_room, + &request.sender_connection_id.to_string(), + ) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -609,18 +642,17 @@ impl Server { })?; self.room_updated(room); } - self.update_user_contacts(user_id).await?; + self.update_user_contacts(request.sender_user_id).await?; Ok(()) } - async fn leave_room(self: Arc, message: TypedEnvelope) -> Result<()> { + async fn leave_room(self: Arc, message: Message) -> Result<()> { let mut contacts_to_update = HashSet::default(); let room_left; { let mut store = self.store().await; - let user_id = store.user_id_for_connection(message.sender_id)?; - let 
left_room = store.leave_room(message.payload.id, message.sender_id)?; - contacts_to_update.insert(user_id); + let left_room = store.leave_room(message.payload.id, message.sender_connection_id)?; + contacts_to_update.insert(message.sender_user_id); for project in left_room.unshared_projects { for connection_id in project.connection_ids() { @@ -640,13 +672,13 @@ impl Server { connection_id, proto::RemoveProjectCollaborator { project_id: project.id.to_proto(), - peer_id: message.sender_id.0, + peer_id: message.sender_connection_id.0, }, )?; } self.peer.send( - message.sender_id, + message.sender_connection_id, proto::UnshareProject { project_id: project.id.to_proto(), }, @@ -655,7 +687,7 @@ impl Server { } self.room_updated(&left_room.room); - room_left = self.room_left(&left_room.room, message.sender_id); + room_left = self.room_left(&left_room.room, message.sender_connection_id); for connection_id in left_room.canceled_call_connection_ids { self.peer @@ -675,13 +707,10 @@ impl Server { async fn call( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let caller_user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let caller_user_id = request.sender_user_id; let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); let initial_project_id = request .payload @@ -703,7 +732,7 @@ impl Server { room_id, recipient_user_id, initial_project_id, - request.sender_id, + request.sender_connection_id, )?; self.room_updated(room); recipient_connection_ids @@ -740,7 +769,7 @@ impl Server { async fn cancel_call( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); @@ -749,7 +778,7 @@ impl Server { let (room, recipient_connection_ids) = store.cancel_call( request.payload.room_id, recipient_user_id, - request.sender_id, + request.sender_connection_id, )?; for 
recipient_id in recipient_connection_ids { self.peer @@ -763,16 +792,12 @@ impl Server { Ok(()) } - async fn decline_call( - self: Arc, - message: TypedEnvelope, - ) -> Result<()> { - let recipient_user_id; + async fn decline_call(self: Arc, message: Message) -> Result<()> { + let recipient_user_id = message.sender_user_id; { let mut store = self.store().await; - recipient_user_id = store.user_id_for_connection(message.sender_id)?; let (room, recipient_connection_ids) = - store.decline_call(message.payload.room_id, message.sender_id)?; + store.decline_call(message.payload.room_id, message.sender_connection_id)?; for recipient_id in recipient_connection_ids { self.peer .send(recipient_id, proto::CallCanceled {}) @@ -786,7 +811,7 @@ impl Server { async fn update_participant_location( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let room_id = request.payload.room_id; @@ -795,7 +820,8 @@ impl Server { .location .ok_or_else(|| anyhow!("invalid location"))?; let mut store = self.store().await; - let room = store.update_participant_location(room_id, location, request.sender_id)?; + let room = + store.update_participant_location(room_id, location, request.sender_connection_id)?; self.room_updated(room); response.send(proto::Ack {})?; Ok(()) @@ -839,20 +865,20 @@ impl Server { async fn share_project( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let project_id = self.app_state.db.register_project(user_id).await?; + let project_id = self + .app_state + .db + .register_project(request.sender_user_id) + .await?; let mut store = self.store().await; let room = store.share_project( request.payload.room_id, project_id, request.payload.worktrees, - request.sender_id, + request.sender_connection_id, )?; response.send(proto::ShareProjectResponse { project_id: project_id.to_proto(), @@ -864,13 
+890,13 @@ impl Server { async fn unshare_project( self: Arc, - message: TypedEnvelope, + message: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(message.payload.project_id); let mut store = self.store().await; - let (room, project) = store.unshare_project(project_id, message.sender_id)?; + let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?; broadcast( - message.sender_id, + message.sender_connection_id, project.guest_connection_ids(), |conn_id| self.peer.send(conn_id, message.payload.clone()), ); @@ -911,26 +937,24 @@ impl Server { async fn join_project( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - + let guest_user_id = request.sender_user_id; let host_user_id; - let guest_user_id; let host_connection_id; { let state = self.store().await; let project = state.project(project_id)?; host_user_id = project.host.user_id; host_connection_id = project.host_connection_id; - guest_user_id = state.user_id_for_connection(request.sender_id)?; }; tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project"); let mut store = self.store().await; - let (project, replica_id) = store.join_project(request.sender_id, project_id)?; + let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?; let peer_count = project.guests.len(); let mut collaborators = Vec::with_capacity(peer_count); collaborators.push(proto::Collaborator { @@ -951,7 +975,7 @@ impl Server { // Add all guests other than the requesting user's own connections as collaborators for (guest_conn_id, guest) in &project.guests { - if request.sender_id != *guest_conn_id { + if request.sender_connection_id != *guest_conn_id { collaborators.push(proto::Collaborator { peer_id: guest_conn_id.0, replica_id: guest.replica_id as u32, @@ -961,14 +985,14 @@ impl Server { } for conn_id in project.connection_ids() { 
- if conn_id != request.sender_id { + if conn_id != request.sender_connection_id { self.peer .send( conn_id, proto::AddProjectCollaborator { project_id: project_id.to_proto(), collaborator: Some(proto::Collaborator { - peer_id: request.sender_id.0, + peer_id: request.sender_connection_id.0, replica_id: replica_id as u32, user_id: guest_user_id.to_proto(), }), @@ -1004,13 +1028,14 @@ impl Server { is_last_update: worktree.is_complete, }; for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer.send(request.sender_id, update.clone())?; + self.peer + .send(request.sender_connection_id, update.clone())?; } // Stream this worktree's diagnostics. for summary in worktree.diagnostic_summaries.values() { self.peer.send( - request.sender_id, + request.sender_connection_id, proto::UpdateDiagnosticSummary { project_id: project_id.to_proto(), worktree_id: *worktree_id, @@ -1022,7 +1047,7 @@ impl Server { for language_server in &project.language_servers { self.peer.send( - request.sender_id, + request.sender_connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), language_server_id: language_server.id, @@ -1038,11 +1063,8 @@ impl Server { Ok(()) } - async fn leave_project( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { - let sender_id = request.sender_id; + async fn leave_project(self: Arc, request: Message) -> Result<()> { + let sender_id = request.sender_connection_id; let project_id = ProjectId::from_proto(request.payload.project_id); let project; { @@ -1073,20 +1095,30 @@ impl Server { async fn update_project( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); { let mut state = self.store().await; let guest_connection_ids = state - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? 
.guest_connection_ids(); - let room = - state.update_project(project_id, &request.payload.worktrees, request.sender_id)?; - broadcast(request.sender_id, guest_connection_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + let room = state.update_project( + project_id, + &request.payload.worktrees, + request.sender_connection_id, + )?; + broadcast( + request.sender_connection_id, + guest_connection_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); self.room_updated(room); }; @@ -1095,13 +1127,13 @@ impl Server { async fn update_worktree( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let worktree_id = request.payload.worktree_id; let connection_ids = self.store().await.update_worktree( - request.sender_id, + request.sender_connection_id, project_id, worktree_id, &request.payload.root_name, @@ -1111,17 +1143,24 @@ impl Server { request.payload.is_last_update, )?; - broadcast(request.sender_id, connection_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + connection_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); response.send(proto::Ack {})?; Ok(()) } async fn update_diagnostic_summary( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let summary = request .payload @@ -1131,55 +1170,76 @@ impl Server { let receiver_ids = self.store().await.update_diagnostic_summary( ProjectId::from_proto(request.payload.project_id), request.payload.worktree_id, - request.sender_id, + request.sender_connection_id, summary, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - 
self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn start_language_server( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.start_language_server( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, request .payload .server .clone() .ok_or_else(|| anyhow!("invalid language server"))?, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn update_language_server( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn forward_project_request( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> where @@ -1189,17 +1249,21 @@ impl Server { let host_connection_id = self .store() .await - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? 
.host_connection_id; let payload = self .peer - .forward_request(request.sender_id, host_connection_id, request.payload) + .forward_request( + request.sender_connection_id, + host_connection_id, + request.payload, + ) .await?; // Ensure project still exists by the time we get the response from the host. self.store() .await - .read_project(project_id, request.sender_id)?; + .read_project(project_id, request.sender_connection_id)?; response.send(payload)?; Ok(()) @@ -1207,26 +1271,26 @@ impl Server { async fn save_buffer( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let host = self .store() .await - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? .host_connection_id; let response_payload = self .peer - .forward_request(request.sender_id, host, request.payload.clone()) + .forward_request(request.sender_connection_id, host, request.payload.clone()) .await?; let mut guests = self .store() .await - .read_project(project_id, request.sender_id)? + .read_project(project_id, request.sender_connection_id)? 
.connection_ids(); - guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id); + guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id); broadcast(host, guests, |conn_id| { self.peer .forward_send(host, conn_id, response_payload.clone()) @@ -1237,10 +1301,10 @@ impl Server { async fn create_buffer_for_peer( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { self.peer.forward_send( - request.sender_id, + request.sender_connection_id, ConnectionId(request.payload.peer_id), request.payload, )?; @@ -1249,76 +1313,101 @@ impl Server { async fn update_buffer( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let receiver_ids = { let store = self.store().await; - store.project_connection_ids(project_id, request.sender_id)? + store.project_connection_ids(project_id, request.sender_connection_id)? }; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); response.send(proto::Ack {})?; Ok(()) } async fn update_buffer_file( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + 
request.payload.clone(), + ) + }, + ); Ok(()) } async fn buffer_reloaded( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } - async fn buffer_saved( - self: Arc, - request: TypedEnvelope, - ) -> Result<()> { + async fn buffer_saved(self: Arc, request: Message) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn follow( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let follower_id = request.sender_id; + let follower_id = request.sender_connection_id; { let store = self.store().await; if !store @@ -1331,7 +1420,7 @@ impl Server { let mut response_payload = self .peer - .forward_request(request.sender_id, leader_id, request.payload) + .forward_request(request.sender_connection_id, leader_id, request.payload) .await?; response_payload .views @@ -1340,28 
+1429,29 @@ impl Server { Ok(()) } - async fn unfollow(self: Arc, request: TypedEnvelope) -> Result<()> { + async fn unfollow(self: Arc, request: Message) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); let store = self.store().await; if !store - .project_connection_ids(project_id, request.sender_id)? + .project_connection_ids(project_id, request.sender_connection_id)? .contains(&leader_id) { Err(anyhow!("no such peer"))?; } self.peer - .forward_send(request.sender_id, leader_id, request.payload)?; + .forward_send(request.sender_connection_id, leader_id, request.payload)?; Ok(()) } async fn update_followers( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let store = self.store().await; - let connection_ids = store.project_connection_ids(project_id, request.sender_id)?; + let connection_ids = + store.project_connection_ids(project_id, request.sender_connection_id)?; let leader_id = request .payload .variant @@ -1374,8 +1464,11 @@ impl Server { for follower_id in &request.payload.follower_ids { let follower_id = ConnectionId(*follower_id); if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer - .forward_send(request.sender_id, follower_id, request.payload.clone())?; + self.peer.forward_send( + request.sender_connection_id, + follower_id, + request.payload.clone(), + )?; } } Ok(()) @@ -1383,7 +1476,7 @@ impl Server { async fn get_users( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { let user_ids = request @@ -1410,13 +1503,9 @@ impl Server { async fn fuzzy_search_users( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; let query = request.payload.query; let db = 
&self.app_state.db; let users = match query.len() { @@ -1430,7 +1519,7 @@ impl Server { }; let users = users .into_iter() - .filter(|user| user.id != user_id) + .filter(|user| user.id != request.sender_user_id) .map(|user| proto::User { id: user.id.to_proto(), avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), @@ -1443,13 +1532,10 @@ impl Server { async fn request_contact( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let requester_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let requester_id = request.sender_user_id; let responder_id = UserId::from_proto(request.payload.responder_id); if requester_id == responder_id { return Err(anyhow!("cannot add yourself as a contact"))?; @@ -1485,13 +1571,10 @@ impl Server { async fn respond_to_contact_request( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let responder_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let responder_id = request.sender_user_id; let requester_id = UserId::from_proto(request.payload.requester_id); if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 { self.app_state @@ -1541,13 +1624,10 @@ impl Server { async fn remove_contact( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let requester_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; + let requester_id = request.sender_user_id; let responder_id = UserId::from_proto(request.payload.user_id); self.app_state .db @@ -1578,33 +1658,40 @@ impl Server { async fn update_diff_base( self: Arc, - request: TypedEnvelope, + request: Message, ) -> Result<()> { let receiver_ids = self.store().await.project_connection_ids( ProjectId::from_proto(request.payload.project_id), - request.sender_id, + request.sender_connection_id, )?; - broadcast(request.sender_id, 
receiver_ids, |connection_id| { - self.peer - .forward_send(request.sender_id, connection_id, request.payload.clone()) - }); + broadcast( + request.sender_connection_id, + receiver_ids, + |connection_id| { + self.peer.forward_send( + request.sender_connection_id, + connection_id, + request.payload.clone(), + ) + }, + ); Ok(()) } async fn get_private_user_info( self: Arc, - request: TypedEnvelope, + request: Message, response: Response, ) -> Result<()> { - let user_id = self - .store() - .await - .user_id_for_connection(request.sender_id)?; - let metrics_id = self.app_state.db.get_user_metrics_id(user_id).await?; + let metrics_id = self + .app_state + .db + .get_user_metrics_id(request.sender_user_id) + .await?; let user = self .app_state .db - .get_user_by_id(user_id) + .get_user_by_id(request.sender_user_id) .await? .ok_or_else(|| anyhow!("user not found"))?; response.send(proto::GetPrivateUserInfoResponse { From 6871bbbc718d8d60951712f03462ce9c69d20c4a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 12:06:43 +0100 Subject: [PATCH 003/109] Start moving `Store` state into the database --- crates/call/src/call.rs | 20 +- crates/call/src/room.rs | 8 +- .../20221109000000_test_schema.sql | 2 +- .../20221111092550_reconnection_support.sql | 2 +- crates/collab/src/db.rs | 354 +++++++++++++++++- crates/collab/src/integration_tests.rs | 14 +- crates/collab/src/rpc.rs | 115 +++--- crates/collab/src/rpc/store.rs | 248 +----------- .../src/incoming_call_notification.rs | 6 +- crates/rpc/proto/zed.proto | 13 +- crates/rpc/src/rpc.rs | 2 +- 11 files changed, 447 insertions(+), 337 deletions(-) diff --git a/crates/call/src/call.rs b/crates/call/src/call.rs index 6b72eb61da37398036869ed0a1f554442a3b52ae..803fbb906adc53ac03cb1826a1f139931a83f8e1 100644 --- a/crates/call/src/call.rs +++ b/crates/call/src/call.rs @@ -22,7 +22,7 @@ pub fn init(client: Arc, user_store: ModelHandle, cx: &mut Mu #[derive(Clone)] pub struct IncomingCall { pub room_id: u64, - pub 
caller: Arc, + pub calling_user: Arc, pub participants: Vec>, pub initial_project: Option, } @@ -78,9 +78,9 @@ impl ActiveCall { user_store.get_users(envelope.payload.participant_user_ids, cx) }) .await?, - caller: user_store + calling_user: user_store .update(&mut cx, |user_store, cx| { - user_store.get_user(envelope.payload.caller_user_id, cx) + user_store.get_user(envelope.payload.calling_user_id, cx) }) .await?, initial_project: envelope.payload.initial_project, @@ -110,13 +110,13 @@ impl ActiveCall { pub fn invite( &mut self, - recipient_user_id: u64, + called_user_id: u64, initial_project: Option>, cx: &mut ModelContext, ) -> Task> { let client = self.client.clone(); let user_store = self.user_store.clone(); - if !self.pending_invites.insert(recipient_user_id) { + if !self.pending_invites.insert(called_user_id) { return Task::ready(Err(anyhow!("user was already invited"))); } @@ -136,13 +136,13 @@ impl ActiveCall { }; room.update(&mut cx, |room, cx| { - room.call(recipient_user_id, initial_project_id, cx) + room.call(called_user_id, initial_project_id, cx) }) .await?; } else { let room = cx .update(|cx| { - Room::create(recipient_user_id, initial_project, client, user_store, cx) + Room::create(called_user_id, initial_project, client, user_store, cx) }) .await?; @@ -155,7 +155,7 @@ impl ActiveCall { let result = invite.await; this.update(&mut cx, |this, cx| { - this.pending_invites.remove(&recipient_user_id); + this.pending_invites.remove(&called_user_id); cx.notify(); }); result @@ -164,7 +164,7 @@ impl ActiveCall { pub fn cancel_invite( &mut self, - recipient_user_id: u64, + called_user_id: u64, cx: &mut ModelContext, ) -> Task> { let room_id = if let Some(room) = self.room() { @@ -178,7 +178,7 @@ impl ActiveCall { client .request(proto::CancelCall { room_id, - recipient_user_id, + called_user_id, }) .await?; anyhow::Ok(()) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 
7d5153950d76d16f3e3185835eed22ac430fda97..3e55dc4ce96d2cd594929da1be5d4507ba183b42 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -149,7 +149,7 @@ impl Room { } pub(crate) fn create( - recipient_user_id: u64, + called_user_id: u64, initial_project: Option>, client: Arc, user_store: ModelHandle, @@ -182,7 +182,7 @@ impl Room { match room .update(&mut cx, |room, cx| { room.leave_when_empty = true; - room.call(recipient_user_id, initial_project_id, cx) + room.call(called_user_id, initial_project_id, cx) }) .await { @@ -487,7 +487,7 @@ impl Room { pub(crate) fn call( &mut self, - recipient_user_id: u64, + called_user_id: u64, initial_project_id: Option, cx: &mut ModelContext, ) -> Task> { @@ -503,7 +503,7 @@ impl Room { let result = client .request(proto::Call { room_id, - recipient_user_id, + called_user_id, initial_project_id, }) .await; diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 731910027e1ef362e0cbe36ceec7ca0e7f5c0f88..93026575230fc3a6ea6ce1e865c77ef10fe2fc6f 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -82,4 +82,4 @@ CREATE TABLE "calls" ( "answering_connection_id" INTEGER, "initial_project_id" INTEGER REFERENCES projects (id) ); -CREATE UNIQUE INDEX "index_calls_on_calling_user_id" ON "calls" ("calling_user_id"); +CREATE UNIQUE INDEX "index_calls_on_called_user_id" ON "calls" ("called_user_id"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 9474beff4296215588344b22368f7aecdd36006a..8f932acff3ff19857298137adf52b9000f2b3d0f 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -44,4 +44,4 @@ CREATE TABLE IF NOT EXISTS "calls" ( 
"answering_connection_id" INTEGER, "initial_project_id" INTEGER REFERENCES projects (id) ); -CREATE UNIQUE INDEX "index_calls_on_calling_user_id" ON "calls" ("calling_user_id"); +CREATE UNIQUE INDEX "index_calls_on_called_user_id" ON "calls" ("called_user_id"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 10da609d57b9b7cfe04927b681b378c07e099b4b..b7d6f995b0b5a595114c6582371f31816542863d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -3,6 +3,7 @@ use anyhow::anyhow; use axum::http::StatusCode; use collections::HashMap; use futures::StreamExt; +use rpc::{proto, ConnectionId}; use serde::{Deserialize, Serialize}; use sqlx::{ migrate::{Migrate as _, Migration, MigrationSource}, @@ -565,6 +566,7 @@ where for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, + for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>, @@ -882,42 +884,352 @@ where }) } - // projects - - /// Registers a new project for the given user. - pub async fn register_project(&self, host_user_id: UserId) -> Result { + pub async fn create_room( + &self, + user_id: UserId, + connection_id: ConnectionId, + ) -> Result { test_support!(self, { - Ok(sqlx::query_scalar( + let mut tx = self.pool.begin().await?; + let live_kit_room = nanoid::nanoid!(30); + let room_id = sqlx::query_scalar( " - INSERT INTO projects(host_user_id) - VALUES ($1) + INSERT INTO rooms (live_kit_room, version) + VALUES ($1, $2) RETURNING id ", ) - .bind(host_user_id) - .fetch_one(&self.pool) + .bind(&live_kit_room) + .bind(0) + .fetch_one(&mut tx) .await - .map(ProjectId)?) 
+ .map(RoomId)?; + + sqlx::query( + " + INSERT INTO room_participants (room_id, user_id, connection_id) + VALUES ($1, $2, $3) + ", + ) + .bind(room_id) + .bind(user_id) + .bind(connection_id.0 as i32) + .execute(&mut tx) + .await?; + + sqlx::query( + " + INSERT INTO calls (room_id, calling_user_id, called_user_id, answering_connection_id) + VALUES ($1, $2, $3, $4) + ", + ) + .bind(room_id) + .bind(user_id) + .bind(user_id) + .bind(connection_id.0 as i32) + .execute(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await }) } - /// Unregisters a project for the given project id. - pub async fn unregister_project(&self, project_id: ProjectId) -> Result<()> { + pub async fn call( + &self, + room_id: RoomId, + calling_user_id: UserId, + called_user_id: UserId, + initial_project_id: Option, + ) -> Result { test_support!(self, { + let mut tx = self.pool.begin().await?; sqlx::query( " - UPDATE projects - SET unregistered = TRUE - WHERE id = $1 + INSERT INTO calls (room_id, calling_user_id, called_user_id, initial_project_id) + VALUES ($1, $2, $3, $4) + ", + ) + .bind(room_id) + .bind(calling_user_id) + .bind(called_user_id) + .bind(initial_project_id) + .execute(&mut tx) + .await?; + + sqlx::query( + " + INSERT INTO room_participants (room_id, user_id) + VALUES ($1, $2) + ", + ) + .bind(room_id) + .bind(called_user_id) + .execute(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await + }) + } + + pub async fn call_failed( + &self, + room_id: RoomId, + called_user_id: UserId, + ) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + sqlx::query( + " + DELETE FROM calls + WHERE room_id = $1 AND called_user_id = $2 + ", + ) + .bind(room_id) + .bind(called_user_id) + .execute(&mut tx) + .await?; + + sqlx::query( + " + DELETE FROM room_participants + WHERE room_id = $1 AND user_id = $2 + ", + ) + .bind(room_id) + .bind(called_user_id) + .execute(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await + 
}) + } + + pub async fn update_room_participant_location( + &self, + room_id: RoomId, + user_id: UserId, + location: proto::ParticipantLocation, + ) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + + let location_kind; + let location_project_id; + match location + .variant + .ok_or_else(|| anyhow!("invalid location"))? + { + proto::participant_location::Variant::SharedProject(project) => { + location_kind = 0; + location_project_id = Some(ProjectId::from_proto(project.id)); + } + proto::participant_location::Variant::UnsharedProject(_) => { + location_kind = 1; + location_project_id = None; + } + proto::participant_location::Variant::External(_) => { + location_kind = 2; + location_project_id = None; + } + } + + sqlx::query( + " + UPDATE room_participants + SET location_kind = $1 AND location_project_id = $2 + WHERE room_id = $1 AND user_id = $2 + ", + ) + .bind(location_kind) + .bind(location_project_id) + .bind(room_id) + .bind(user_id) + .execute(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await + }) + } + + async fn commit_room_transaction( + &self, + room_id: RoomId, + mut tx: sqlx::Transaction<'_, D>, + ) -> Result { + sqlx::query( + " + UPDATE rooms + SET version = version + 1 + WHERE id = $1 + ", + ) + .bind(room_id) + .execute(&mut tx) + .await?; + + let room: Room = sqlx::query_as( + " + SELECT * + FROM rooms + WHERE id = $1 + ", + ) + .bind(room_id) + .fetch_one(&mut tx) + .await?; + + let mut db_participants = + sqlx::query_as::<_, (UserId, Option, Option, Option)>( + " + SELECT user_id, connection_id, location_kind, location_project_id + FROM room_participants + WHERE room_id = $1 + ", + ) + .bind(room_id) + .fetch(&mut tx); + + let mut participants = Vec::new(); + let mut pending_participant_user_ids = Vec::new(); + while let Some(participant) = db_participants.next().await { + let (user_id, connection_id, _location_kind, _location_project_id) = participant?; + if let Some(connection_id) = 
connection_id { + participants.push(proto::Participant { + user_id: user_id.to_proto(), + peer_id: connection_id as u32, + projects: Default::default(), + location: Some(proto::ParticipantLocation { + variant: Some(proto::participant_location::Variant::External( + Default::default(), + )), + }), + }); + } else { + pending_participant_user_ids.push(user_id.to_proto()); + } + } + drop(db_participants); + + for participant in &mut participants { + let mut entries = sqlx::query_as::<_, (ProjectId, String)>( + " + SELECT projects.id, worktrees.root_name + FROM projects + LEFT JOIN worktrees ON projects.id = worktrees.project_id + WHERE room_id = $1 AND host_user_id = $2 + ", + ) + .bind(room_id) + .fetch(&mut tx); + + let mut projects = HashMap::default(); + while let Some(entry) = entries.next().await { + let (project_id, worktree_root_name) = entry?; + let participant_project = + projects + .entry(project_id) + .or_insert(proto::ParticipantProject { + id: project_id.to_proto(), + worktree_root_names: Default::default(), + }); + participant_project + .worktree_root_names + .push(worktree_root_name); + } + + participant.projects = projects.into_values().collect(); + } + + tx.commit().await?; + + Ok(proto::Room { + id: room.id.to_proto(), + version: room.version as u64, + live_kit_room: room.live_kit_room, + participants, + pending_participant_user_ids, + }) + } + + // projects + + pub async fn share_project( + &self, + user_id: UserId, + connection_id: ConnectionId, + room_id: RoomId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result<(ProjectId, proto::Room)> { + test_support!(self, { + let mut tx = self.pool.begin().await?; + let project_id = sqlx::query_scalar( + " + INSERT INTO projects (host_user_id, room_id) + VALUES ($1) + RETURNING id + ", + ) + .bind(user_id) + .bind(room_id) + .fetch_one(&mut tx) + .await + .map(ProjectId)?; + + for worktree in worktrees { + sqlx::query( + " + INSERT INTO worktrees (id, project_id, root_name) + ", + ) + .bind(worktree.id 
as i32) + .bind(project_id) + .bind(&worktree.root_name) + .execute(&mut tx) + .await?; + } + + sqlx::query( + " + INSERT INTO project_collaborators ( + project_id, + connection_id, + user_id, + replica_id, + is_host + ) + VALUES ($1, $2, $3, $4, $5) ", ) .bind(project_id) - .execute(&self.pool) + .bind(connection_id.0 as i32) + .bind(user_id) + .bind(0) + .bind(true) + .execute(&mut tx) .await?; - Ok(()) + + let room = self.commit_room_transaction(room_id, tx).await?; + Ok((project_id, room)) }) } + pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { + todo!() + // test_support!(self, { + // sqlx::query( + // " + // UPDATE projects + // SET unregistered = TRUE + // WHERE id = $1 + // ", + // ) + // .bind(project_id) + // .execute(&self.pool) + // .await?; + // Ok(()) + // }) + } + // contacts pub async fn get_contacts(&self, user_id: UserId) -> Result> { @@ -1246,6 +1558,14 @@ pub struct User { pub connected_once: bool, } +id_type!(RoomId); +#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] +pub struct Room { + pub id: RoomId, + pub version: i32, + pub live_kit_room: String, +} + id_type!(ProjectId); #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] pub struct Project { diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 0a6c01a691aad8d4d9aad5555ea2b8937c626e86..6d3cff1718e983812539374373b540ef5ba6f27f 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -104,7 +104,7 @@ async fn test_basic_calls( // User B receives the call. let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); let call_b = incoming_call_b.next().await.unwrap().unwrap(); - assert_eq!(call_b.caller.github_login, "user_a"); + assert_eq!(call_b.calling_user.github_login, "user_a"); // User B connects via another client and also receives a ring on the newly-connected client. 
let _client_b2 = server.create_client(cx_b2, "user_b").await; @@ -112,7 +112,7 @@ async fn test_basic_calls( let mut incoming_call_b2 = active_call_b2.read_with(cx_b2, |call, _| call.incoming()); deterministic.run_until_parked(); let call_b2 = incoming_call_b2.next().await.unwrap().unwrap(); - assert_eq!(call_b2.caller.github_login, "user_a"); + assert_eq!(call_b2.calling_user.github_login, "user_a"); // User B joins the room using the first client. active_call_b @@ -165,7 +165,7 @@ async fn test_basic_calls( // User C receives the call, but declines it. let call_c = incoming_call_c.next().await.unwrap().unwrap(); - assert_eq!(call_c.caller.github_login, "user_b"); + assert_eq!(call_c.calling_user.github_login, "user_b"); active_call_c.update(cx_c, |call, _| call.decline_incoming().unwrap()); assert!(incoming_call_c.next().await.unwrap().is_none()); @@ -308,7 +308,7 @@ async fn test_room_uniqueness( // User B receives the call from user A. let mut incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); let call_b1 = incoming_call_b.next().await.unwrap().unwrap(); - assert_eq!(call_b1.caller.github_login, "user_a"); + assert_eq!(call_b1.calling_user.github_login, "user_a"); // Ensure calling users A and B from client C fails. 
active_call_c @@ -367,7 +367,7 @@ async fn test_room_uniqueness( .unwrap(); deterministic.run_until_parked(); let call_b2 = incoming_call_b.next().await.unwrap().unwrap(); - assert_eq!(call_b2.caller.github_login, "user_c"); + assert_eq!(call_b2.calling_user.github_login, "user_c"); } #[gpui::test(iterations = 10)] @@ -695,7 +695,7 @@ async fn test_share_project( let incoming_call_b = active_call_b.read_with(cx_b, |call, _| call.incoming()); deterministic.run_until_parked(); let call = incoming_call_b.borrow().clone().unwrap(); - assert_eq!(call.caller.github_login, "user_a"); + assert_eq!(call.calling_user.github_login, "user_a"); let initial_project = call.initial_project.unwrap(); active_call_b .update(cx_b, |call, cx| call.accept_incoming(cx)) @@ -766,7 +766,7 @@ async fn test_share_project( let incoming_call_c = active_call_c.read_with(cx_c, |call, _| call.incoming()); deterministic.run_until_parked(); let call = incoming_call_c.borrow().clone().unwrap(); - assert_eq!(call.caller.github_login, "user_b"); + assert_eq!(call.calling_user.github_login, "user_b"); let initial_project = call.initial_project.unwrap(); active_call_c .update(cx_c, |call, cx| call.accept_incoming(cx)) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 757c765838551666613445121bfc4625ad89a2e6..75ff703b1f6c3d283be78c85791d8f7a86977097 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod store; use crate::{ auth, - db::{self, ProjectId, User, UserId}, + db::{self, ProjectId, RoomId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -486,7 +486,7 @@ impl Server { for project_id in projects_to_unshare { self.app_state .db - .unregister_project(project_id) + .unshare_project(project_id) .await .trace_err(); } @@ -559,11 +559,11 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let room; - { - let mut store = self.store().await; - room = store.create_room(request.sender_connection_id)?.clone(); - } + let 
room = self + .app_state + .db + .create_room(request.sender_user_id, request.sender_connection_id) + .await?; let live_kit_connection_info = if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { @@ -710,8 +710,9 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let caller_user_id = request.sender_user_id; - let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); + let room_id = RoomId::from_proto(request.payload.room_id); + let calling_user_id = request.sender_user_id; + let called_user_id = UserId::from_proto(request.payload.called_user_id); let initial_project_id = request .payload .initial_project_id @@ -719,31 +720,44 @@ impl Server { if !self .app_state .db - .has_contact(caller_user_id, recipient_user_id) + .has_contact(calling_user_id, called_user_id) .await? { return Err(anyhow!("cannot call a user who isn't a contact"))?; } - let room_id = request.payload.room_id; - let mut calls = { - let mut store = self.store().await; - let (room, recipient_connection_ids, incoming_call) = store.call( - room_id, - recipient_user_id, - initial_project_id, - request.sender_connection_id, - )?; - self.room_updated(room); - recipient_connection_ids - .into_iter() - .map(|recipient_connection_id| { - self.peer - .request(recipient_connection_id, incoming_call.clone()) - }) - .collect::>() + let room = self + .app_state + .db + .call(room_id, calling_user_id, called_user_id, initial_project_id) + .await?; + self.room_updated(&room); + self.update_user_contacts(called_user_id).await?; + + let incoming_call = proto::IncomingCall { + room_id: room_id.to_proto(), + calling_user_id: calling_user_id.to_proto(), + participant_user_ids: room + .participants + .iter() + .map(|participant| participant.user_id) + .collect(), + initial_project: room.participants.iter().find_map(|participant| { + let initial_project_id = initial_project_id?.to_proto(); + participant + .projects + .iter() + .find(|project| project.id == 
initial_project_id) + .cloned() + }), }; - self.update_user_contacts(recipient_user_id).await?; + + let mut calls = self + .store() + .await + .connection_ids_for_user(called_user_id) + .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) + .collect::>(); while let Some(call_response) = calls.next().await { match call_response.as_ref() { @@ -757,12 +771,13 @@ impl Server { } } - { - let mut store = self.store().await; - let room = store.call_failed(room_id, recipient_user_id)?; - self.room_updated(&room); - } - self.update_user_contacts(recipient_user_id).await?; + let room = self + .app_state + .db + .call_failed(room_id, called_user_id) + .await?; + self.room_updated(&room); + self.update_user_contacts(called_user_id).await?; Err(anyhow!("failed to ring call recipient"))? } @@ -772,7 +787,7 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let recipient_user_id = UserId::from_proto(request.payload.recipient_user_id); + let recipient_user_id = UserId::from_proto(request.payload.called_user_id); { let mut store = self.store().await; let (room, recipient_connection_ids) = store.cancel_call( @@ -814,15 +829,17 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let room_id = request.payload.room_id; + let room_id = RoomId::from_proto(request.payload.room_id); let location = request .payload .location .ok_or_else(|| anyhow!("invalid location"))?; - let mut store = self.store().await; - let room = - store.update_participant_location(room_id, location, request.sender_connection_id)?; - self.room_updated(room); + let room = self + .app_state + .db + .update_room_participant_location(room_id, request.sender_user_id, location) + .await?; + self.room_updated(&room); response.send(proto::Ack {})?; Ok(()) } @@ -868,22 +885,20 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let project_id = self + let (project_id, room) = self .app_state .db - 
.register_project(request.sender_user_id) + .share_project( + request.sender_user_id, + request.sender_connection_id, + RoomId::from_proto(request.payload.room_id), + &request.payload.worktrees, + ) .await?; - let mut store = self.store().await; - let room = store.share_project( - request.payload.room_id, - project_id, - request.payload.worktrees, - request.sender_connection_id, - )?; response.send(proto::ShareProjectResponse { project_id: project_id.to_proto(), })?; - self.room_updated(room); + self.room_updated(&room); Ok(()) } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 81ef594ccd75b4098ec48af7f1c8a93b260d523b..72da82ea8ce1c6a8ab5539531467de2a6296c2bc 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -1,12 +1,10 @@ use crate::db::{self, ProjectId, UserId}; use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; -use nanoid::nanoid; use rpc::{proto, ConnectionId}; use serde::Serialize; use std::{borrow::Cow, mem, path::PathBuf, str}; use tracing::instrument; -use util::post_inc; pub type RoomId = u64; @@ -34,7 +32,7 @@ struct ConnectionState { #[derive(Copy, Clone, Eq, PartialEq, Serialize)] pub struct Call { - pub caller_user_id: UserId, + pub calling_user_id: UserId, pub room_id: RoomId, pub connection_id: Option, pub initial_project_id: Option, @@ -147,7 +145,7 @@ impl Store { let room = self.room(active_call.room_id)?; Some(proto::IncomingCall { room_id: active_call.room_id, - caller_user_id: active_call.caller_user_id.to_proto(), + calling_user_id: active_call.calling_user_id.to_proto(), participant_user_ids: room .participants .iter() @@ -285,47 +283,6 @@ impl Store { } } - pub fn create_room(&mut self, creator_connection_id: ConnectionId) -> Result<&proto::Room> { - let connection = self - .connections - .get_mut(&creator_connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let connected_user = self - .connected_users - 
.get_mut(&connection.user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!( - connected_user.active_call.is_none(), - "can't create a room with an active call" - ); - - let room_id = post_inc(&mut self.next_room_id); - let room = proto::Room { - id: room_id, - participants: vec![proto::Participant { - user_id: connection.user_id.to_proto(), - peer_id: creator_connection_id.0, - projects: Default::default(), - location: Some(proto::ParticipantLocation { - variant: Some(proto::participant_location::Variant::External( - proto::participant_location::External {}, - )), - }), - }], - pending_participant_user_ids: Default::default(), - live_kit_room: nanoid!(30), - }; - - self.rooms.insert(room_id, room); - connected_user.active_call = Some(Call { - caller_user_id: connection.user_id, - room_id, - connection_id: Some(creator_connection_id), - initial_project_id: None, - }); - Ok(self.rooms.get(&room_id).unwrap()) - } - pub fn join_room( &mut self, room_id: RoomId, @@ -424,7 +381,7 @@ impl Store { .get_mut(&UserId::from_proto(*pending_participant_user_id)) { if let Some(call) = connected_user.active_call.as_ref() { - if call.caller_user_id == user_id { + if call.calling_user_id == user_id { connected_user.active_call.take(); canceled_call_connection_ids .extend(connected_user.connection_ids.iter().copied()); @@ -462,101 +419,10 @@ impl Store { &self.rooms } - pub fn call( - &mut self, - room_id: RoomId, - recipient_user_id: UserId, - initial_project_id: Option, - from_connection_id: ConnectionId, - ) -> Result<(&proto::Room, Vec, proto::IncomingCall)> { - let caller_user_id = self.user_id_for_connection(from_connection_id)?; - - let recipient_connection_ids = self - .connection_ids_for_user(recipient_user_id) - .collect::>(); - let mut recipient = self - .connected_users - .get_mut(&recipient_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!( - recipient.active_call.is_none(), - "recipient is already on another call" - ); - 
- let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - anyhow::ensure!( - room.participants - .iter() - .any(|participant| participant.peer_id == from_connection_id.0), - "no such room" - ); - anyhow::ensure!( - room.pending_participant_user_ids - .iter() - .all(|user_id| UserId::from_proto(*user_id) != recipient_user_id), - "cannot call the same user more than once" - ); - room.pending_participant_user_ids - .push(recipient_user_id.to_proto()); - - if let Some(initial_project_id) = initial_project_id { - let project = self - .projects - .get(&initial_project_id) - .ok_or_else(|| anyhow!("no such project"))?; - anyhow::ensure!(project.room_id == room_id, "no such project"); - } - - recipient.active_call = Some(Call { - caller_user_id, - room_id, - connection_id: None, - initial_project_id, - }); - - Ok(( - room, - recipient_connection_ids, - proto::IncomingCall { - room_id, - caller_user_id: caller_user_id.to_proto(), - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: initial_project_id - .and_then(|id| Self::build_participant_project(id, &self.projects)), - }, - )) - } - - pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> { - let mut recipient = self - .connected_users - .get_mut(&to_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!(recipient - .active_call - .map_or(false, |call| call.room_id == room_id - && call.connection_id.is_none())); - recipient.active_call = None; - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != to_user_id); - Ok(room) - } - pub fn cancel_call( &mut self, room_id: RoomId, - recipient_user_id: UserId, + called_user_id: UserId, canceller_connection_id: ConnectionId, ) -> Result<(&proto::Room, HashSet)> { let canceller_user_id = 
self.user_id_for_connection(canceller_connection_id)?; @@ -566,7 +432,7 @@ impl Store { .ok_or_else(|| anyhow!("no such connection"))?; let recipient = self .connected_users - .get(&recipient_user_id) + .get(&called_user_id) .ok_or_else(|| anyhow!("no such connection"))?; let canceller_active_call = canceller .active_call @@ -595,9 +461,9 @@ impl Store { .get_mut(&room_id) .ok_or_else(|| anyhow!("no such room"))?; room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != recipient_user_id); + .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); - let recipient = self.connected_users.get_mut(&recipient_user_id).unwrap(); + let recipient = self.connected_users.get_mut(&called_user_id).unwrap(); recipient.active_call.take(); Ok((room, recipient.connection_ids.clone())) @@ -608,10 +474,10 @@ impl Store { room_id: RoomId, recipient_connection_id: ConnectionId, ) -> Result<(&proto::Room, Vec)> { - let recipient_user_id = self.user_id_for_connection(recipient_connection_id)?; + let called_user_id = self.user_id_for_connection(recipient_connection_id)?; let recipient = self .connected_users - .get_mut(&recipient_user_id) + .get_mut(&called_user_id) .ok_or_else(|| anyhow!("no such connection"))?; if let Some(active_call) = recipient.active_call { anyhow::ensure!(active_call.room_id == room_id, "no such room"); @@ -621,112 +487,20 @@ impl Store { ); recipient.active_call.take(); let recipient_connection_ids = self - .connection_ids_for_user(recipient_user_id) + .connection_ids_for_user(called_user_id) .collect::>(); let room = self .rooms .get_mut(&active_call.room_id) .ok_or_else(|| anyhow!("no such room"))?; room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != recipient_user_id); + .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); Ok((room, recipient_connection_ids)) } else { Err(anyhow!("user is not being called")) } } - pub fn update_participant_location( - &mut self, - room_id: 
RoomId, - location: proto::ParticipantLocation, - connection_id: ConnectionId, - ) -> Result<&proto::Room> { - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - if let Some(proto::participant_location::Variant::SharedProject(project)) = - location.variant.as_ref() - { - anyhow::ensure!( - room.participants - .iter() - .flat_map(|participant| &participant.projects) - .any(|participant_project| participant_project.id == project.id), - "no such project" - ); - } - - let participant = room - .participants - .iter_mut() - .find(|participant| participant.peer_id == connection_id.0) - .ok_or_else(|| anyhow!("no such room"))?; - participant.location = Some(location); - - Ok(room) - } - - pub fn share_project( - &mut self, - room_id: RoomId, - project_id: ProjectId, - worktrees: Vec, - host_connection_id: ConnectionId, - ) -> Result<&proto::Room> { - let connection = self - .connections - .get_mut(&host_connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - let participant = room - .participants - .iter_mut() - .find(|participant| participant.peer_id == host_connection_id.0) - .ok_or_else(|| anyhow!("no such room"))?; - - connection.projects.insert(project_id); - self.projects.insert( - project_id, - Project { - id: project_id, - room_id, - host_connection_id, - host: Collaborator { - user_id: connection.user_id, - replica_id: 0, - admin: connection.admin, - }, - guests: Default::default(), - active_replica_ids: Default::default(), - worktrees: worktrees - .into_iter() - .map(|worktree| { - ( - worktree.id, - Worktree { - root_name: worktree.root_name, - visible: worktree.visible, - ..Default::default() - }, - ) - }) - .collect(), - language_servers: Default::default(), - }, - ); - - participant - .projects - .extend(Self::build_participant_project(project_id, &self.projects)); - - Ok(room) - } - pub fn unshare_project( &mut 
self, project_id: ProjectId, diff --git a/crates/collab_ui/src/incoming_call_notification.rs b/crates/collab_ui/src/incoming_call_notification.rs index e5c4b27d7e4abee3138d107c827b590261115331..a51fb4891d20ee303d35992ef1c2dbc298dd1562 100644 --- a/crates/collab_ui/src/incoming_call_notification.rs +++ b/crates/collab_ui/src/incoming_call_notification.rs @@ -74,7 +74,7 @@ impl IncomingCallNotification { let active_call = ActiveCall::global(cx); if action.accept { let join = active_call.update(cx, |active_call, cx| active_call.accept_incoming(cx)); - let caller_user_id = self.call.caller.id; + let caller_user_id = self.call.calling_user.id; let initial_project_id = self.call.initial_project.as_ref().map(|project| project.id); cx.spawn_weak(|_, mut cx| async move { join.await?; @@ -105,7 +105,7 @@ impl IncomingCallNotification { .as_ref() .unwrap_or(&default_project); Flex::row() - .with_children(self.call.caller.avatar.clone().map(|avatar| { + .with_children(self.call.calling_user.avatar.clone().map(|avatar| { Image::new(avatar) .with_style(theme.caller_avatar) .aligned() @@ -115,7 +115,7 @@ impl IncomingCallNotification { Flex::column() .with_child( Label::new( - self.call.caller.github_login.clone(), + self.call.calling_user.github_login.clone(), theme.caller_username.text.clone(), ) .contained() diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index ded708370d3f64d00c478661911e34e37fa8dd98..07e6fae3a81e6f2aaee0ec2678d553787c75d447 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -164,9 +164,10 @@ message LeaveRoom { message Room { uint64 id = 1; - repeated Participant participants = 2; - repeated uint64 pending_participant_user_ids = 3; - string live_kit_room = 4; + uint64 version = 2; + repeated Participant participants = 3; + repeated uint64 pending_participant_user_ids = 4; + string live_kit_room = 5; } message Participant { @@ -199,13 +200,13 @@ message ParticipantLocation { message Call { uint64 room_id = 1; - 
uint64 recipient_user_id = 2; + uint64 called_user_id = 2; optional uint64 initial_project_id = 3; } message IncomingCall { uint64 room_id = 1; - uint64 caller_user_id = 2; + uint64 calling_user_id = 2; repeated uint64 participant_user_ids = 3; optional ParticipantProject initial_project = 4; } @@ -214,7 +215,7 @@ message CallCanceled {} message CancelCall { uint64 room_id = 1; - uint64 recipient_user_id = 2; + uint64 called_user_id = 2; } message DeclineCall { diff --git a/crates/rpc/src/rpc.rs b/crates/rpc/src/rpc.rs index b6aef64677b6f06716a6ea40d9b52a42017c3543..5ca5711d9ca8c43cd5f1979ee76ea11e61053bec 100644 --- a/crates/rpc/src/rpc.rs +++ b/crates/rpc/src/rpc.rs @@ -6,4 +6,4 @@ pub use conn::Connection; pub use peer::*; mod macros; -pub const PROTOCOL_VERSION: u32 = 39; +pub const PROTOCOL_VERSION: u32 = 40; From 58947c5c7269ec5de2421cd018abe0d254626695 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 14:28:26 +0100 Subject: [PATCH 004/109] Move incoming calls into `Db` --- crates/collab/src/db.rs | 89 +++++++++++++++++++++++++++++++--- crates/collab/src/rpc.rs | 31 +++--------- crates/collab/src/rpc/store.rs | 48 +----------------- 3 files changed, 89 insertions(+), 79 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index b7d6f995b0b5a595114c6582371f31816542863d..506606274d93e5d550888e130cb8915f222953e7 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -940,7 +940,7 @@ where calling_user_id: UserId, called_user_id: UserId, initial_project_id: Option, - ) -> Result { + ) -> Result<(proto::Room, proto::IncomingCall)> { test_support!(self, { let mut tx = self.pool.begin().await?; sqlx::query( @@ -967,10 +967,67 @@ where .execute(&mut tx) .await?; - self.commit_room_transaction(room_id, tx).await + let room = self.commit_room_transaction(room_id, tx).await?; + let incoming_call = + Self::build_incoming_call(&room, calling_user_id, initial_project_id); + Ok((room, incoming_call)) }) } + pub 
async fn incoming_call_for_user( + &self, + user_id: UserId, + ) -> Result> { + test_support!(self, { + let mut tx = self.pool.begin().await?; + let call = sqlx::query_as::<_, Call>( + " + SELECT * + FROM calls + WHERE called_user_id = $1 AND answering_connection_id IS NULL + ", + ) + .bind(user_id) + .fetch_optional(&mut tx) + .await?; + + if let Some(call) = call { + let room = self.get_room(call.room_id, &mut tx).await?; + Ok(Some(Self::build_incoming_call( + &room, + call.calling_user_id, + call.initial_project_id, + ))) + } else { + Ok(None) + } + }) + } + + fn build_incoming_call( + room: &proto::Room, + calling_user_id: UserId, + initial_project_id: Option, + ) -> proto::IncomingCall { + proto::IncomingCall { + room_id: room.id, + calling_user_id: calling_user_id.to_proto(), + participant_user_ids: room + .participants + .iter() + .map(|participant| participant.user_id) + .collect(), + initial_project: room.participants.iter().find_map(|participant| { + let initial_project_id = initial_project_id?.to_proto(); + participant + .projects + .iter() + .find(|project| project.id == initial_project_id) + .cloned() + }), + } + } + pub async fn call_failed( &self, room_id: RoomId, @@ -1066,7 +1123,17 @@ where .bind(room_id) .execute(&mut tx) .await?; + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + + Ok(room) + } + async fn get_room( + &self, + room_id: RoomId, + tx: &mut sqlx::Transaction<'_, D>, + ) -> Result { let room: Room = sqlx::query_as( " SELECT * @@ -1075,7 +1142,7 @@ where ", ) .bind(room_id) - .fetch_one(&mut tx) + .fetch_one(&mut *tx) .await?; let mut db_participants = @@ -1087,7 +1154,7 @@ where ", ) .bind(room_id) - .fetch(&mut tx); + .fetch(&mut *tx); let mut participants = Vec::new(); let mut pending_participant_user_ids = Vec::new(); @@ -1120,7 +1187,7 @@ where ", ) .bind(room_id) - .fetch(&mut tx); + .fetch(&mut *tx); let mut projects = HashMap::default(); while let Some(entry) = entries.next().await { @@ -1139,9 +1206,6 
@@ where participant.projects = projects.into_values().collect(); } - - tx.commit().await?; - Ok(proto::Room { id: room.id.to_proto(), version: room.version as u64, @@ -1566,6 +1630,15 @@ pub struct Room { pub live_kit_room: String, } +#[derive(Clone, Debug, Default, FromRow, PartialEq)] +pub struct Call { + pub room_id: RoomId, + pub calling_user_id: UserId, + pub called_user_id: UserId, + pub answering_connection_id: Option, + pub initial_project_id: Option, +} + id_type!(ProjectId); #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] pub struct Project { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 75ff703b1f6c3d283be78c85791d8f7a86977097..64affdb8252c0bce3dc318ecbc8e45b76fb5273d 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -346,11 +346,7 @@ impl Server { { let mut store = this.store().await; - let incoming_call = store.add_connection(connection_id, user_id, user.admin); - if let Some(incoming_call) = incoming_call { - this.peer.send(connection_id, incoming_call)?; - } - + store.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; if let Some((code, count)) = invite_code { @@ -360,6 +356,11 @@ impl Server { })?; } } + + if let Some(incoming_call) = this.app_state.db.incoming_call_for_user(user_id).await? 
{ + this.peer.send(connection_id, incoming_call)?; + } + this.update_user_contacts(user_id).await?; let handle_io = handle_io.fuse(); @@ -726,7 +727,7 @@ impl Server { return Err(anyhow!("cannot call a user who isn't a contact"))?; } - let room = self + let (room, incoming_call) = self .app_state .db .call(room_id, calling_user_id, called_user_id, initial_project_id) @@ -734,24 +735,6 @@ impl Server { self.room_updated(&room); self.update_user_contacts(called_user_id).await?; - let incoming_call = proto::IncomingCall { - room_id: room_id.to_proto(), - calling_user_id: calling_user_id.to_proto(), - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: room.participants.iter().find_map(|participant| { - let initial_project_id = initial_project_id?.to_proto(); - participant - .projects - .iter() - .find(|project| project.id == initial_project_id) - .cloned() - }), - }; - let mut calls = self .store() .await diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 72da82ea8ce1c6a8ab5539531467de2a6296c2bc..f16910fac514bc0def6b17cf5a5e2ff97e169557 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -122,12 +122,7 @@ impl Store { } #[instrument(skip(self))] - pub fn add_connection( - &mut self, - connection_id: ConnectionId, - user_id: UserId, - admin: bool, - ) -> Option { + pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { self.connections.insert( connection_id, ConnectionState { @@ -138,27 +133,6 @@ impl Store { ); let connected_user = self.connected_users.entry(user_id).or_default(); connected_user.connection_ids.insert(connection_id); - if let Some(active_call) = connected_user.active_call { - if active_call.connection_id.is_some() { - None - } else { - let room = self.room(active_call.room_id)?; - Some(proto::IncomingCall { - room_id: active_call.room_id, - calling_user_id: 
active_call.calling_user_id.to_proto(), - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: active_call - .initial_project_id - .and_then(|id| Self::build_participant_project(id, &self.projects)), - }) - } - } else { - None - } } #[instrument(skip(self))] @@ -411,10 +385,6 @@ impl Store { }) } - pub fn room(&self, room_id: RoomId) -> Option<&proto::Room> { - self.rooms.get(&room_id) - } - pub fn rooms(&self) -> &BTreeMap { &self.rooms } @@ -740,22 +710,6 @@ impl Store { Ok(connection_ids) } - fn build_participant_project( - project_id: ProjectId, - projects: &BTreeMap, - ) -> Option { - Some(proto::ParticipantProject { - id: project_id.to_proto(), - worktree_root_names: projects - .get(&project_id)? - .worktrees - .values() - .filter(|worktree| worktree.visible) - .map(|worktree| worktree.root_name.clone()) - .collect(), - }) - } - pub fn project_connection_ids( &self, project_id: ProjectId, From cc58607c3b0d23d5907008d0f8eb1e9cfc0a8bab Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 14:43:40 +0100 Subject: [PATCH 005/109] Move `Store::join_room` into `Db::join_room` --- crates/collab/src/db.rs | 85 ++++++++++++++++++++++++++++++++++ crates/collab/src/rpc.rs | 71 +++++++++++++++------------- crates/collab/src/rpc/store.rs | 51 -------------------- 3 files changed, 125 insertions(+), 82 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 506606274d93e5d550888e130cb8915f222953e7..7cc0dc35fe8689d7d46f88e290855556e7f0574a 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1061,6 +1061,91 @@ where }) } + pub async fn join_room( + &self, + room_id: RoomId, + user_id: UserId, + connection_id: ConnectionId, + ) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + sqlx::query( + " + UPDATE calls + SET answering_connection_id = $1 + WHERE room_id = $2 AND called_user_id = $3 + RETURNING 1 + ", + ) + 
.bind(connection_id.0 as i32) + .bind(room_id) + .bind(user_id) + .fetch_one(&mut tx) + .await?; + + sqlx::query( + " + UPDATE room_participants + SET connection_id = $1 + WHERE room_id = $2 AND user_id = $3 + RETURNING 1 + ", + ) + .bind(connection_id.0 as i32) + .bind(room_id) + .bind(user_id) + .fetch_one(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await + }) + + // let connection = self + // .connections + // .get_mut(&connection_id) + // .ok_or_else(|| anyhow!("no such connection"))?; + // let user_id = connection.user_id; + // let recipient_connection_ids = self.connection_ids_for_user(user_id).collect::>(); + + // let connected_user = self + // .connected_users + // .get_mut(&user_id) + // .ok_or_else(|| anyhow!("no such connection"))?; + // let active_call = connected_user + // .active_call + // .as_mut() + // .ok_or_else(|| anyhow!("not being called"))?; + // anyhow::ensure!( + // active_call.room_id == room_id && active_call.connection_id.is_none(), + // "not being called on this room" + // ); + + // let room = self + // .rooms + // .get_mut(&room_id) + // .ok_or_else(|| anyhow!("no such room"))?; + // anyhow::ensure!( + // room.pending_participant_user_ids + // .contains(&user_id.to_proto()), + // anyhow!("no such room") + // ); + // room.pending_participant_user_ids + // .retain(|pending| *pending != user_id.to_proto()); + // room.participants.push(proto::Participant { + // user_id: user_id.to_proto(), + // peer_id: connection_id.0, + // projects: Default::default(), + // location: Some(proto::ParticipantLocation { + // variant: Some(proto::participant_location::Variant::External( + // proto::participant_location::External {}, + // )), + // }), + // }); + // active_call.connection_id = Some(connection_id); + + // Ok((room, recipient_connection_ids)) + } + pub async fn update_room_participant_location( &self, room_id: RoomId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 
64affdb8252c0bce3dc318ecbc8e45b76fb5273d..c7c222ee1c30511f6bf296e62b502d5025292b2f 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -607,42 +607,51 @@ impl Server { request: Message, response: Response, ) -> Result<()> { + let room = self + .app_state + .db + .join_room( + RoomId::from_proto(request.payload.id), + request.sender_user_id, + request.sender_connection_id, + ) + .await?; + for recipient_id in self + .store() + .await + .connection_ids_for_user(request.sender_user_id) { - let mut store = self.store().await; - let (room, recipient_connection_ids) = - store.join_room(request.payload.id, request.sender_connection_id)?; - for recipient_id in recipient_connection_ids { - self.peer - .send(recipient_id, proto::CallCanceled {}) - .trace_err(); - } + self.peer + .send(recipient_id, proto::CallCanceled {}) + .trace_err(); + } - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(token) = live_kit - .room_token( - &room.live_kit_room, - &request.sender_connection_id.to_string(), - ) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } + let live_kit_connection_info = + if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { + if let Some(token) = live_kit + .room_token( + &room.live_kit_room, + &request.sender_connection_id.to_string(), + ) + .trace_err() + { + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + }) } else { None - }; + } + } else { + None + }; + + self.room_updated(&room); + response.send(proto::JoinRoomResponse { + room: Some(room), + live_kit_connection_info, + })?; - response.send(proto::JoinRoomResponse { - room: Some(room.clone()), - live_kit_connection_info, - })?; - self.room_updated(room); - } self.update_user_contacts(request.sender_user_id).await?; Ok(()) } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs 
index f16910fac514bc0def6b17cf5a5e2ff97e169557..dfd534dbe9502414e273f9924f00764945e515a1 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -257,57 +257,6 @@ impl Store { } } - pub fn join_room( - &mut self, - room_id: RoomId, - connection_id: ConnectionId, - ) -> Result<(&proto::Room, Vec)> { - let connection = self - .connections - .get_mut(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let user_id = connection.user_id; - let recipient_connection_ids = self.connection_ids_for_user(user_id).collect::>(); - - let connected_user = self - .connected_users - .get_mut(&user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let active_call = connected_user - .active_call - .as_mut() - .ok_or_else(|| anyhow!("not being called"))?; - anyhow::ensure!( - active_call.room_id == room_id && active_call.connection_id.is_none(), - "not being called on this room" - ); - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - anyhow::ensure!( - room.pending_participant_user_ids - .contains(&user_id.to_proto()), - anyhow!("no such room") - ); - room.pending_participant_user_ids - .retain(|pending| *pending != user_id.to_proto()); - room.participants.push(proto::Participant { - user_id: user_id.to_proto(), - peer_id: connection_id.0, - projects: Default::default(), - location: Some(proto::ParticipantLocation { - variant: Some(proto::participant_location::Variant::External( - proto::participant_location::External {}, - )), - }), - }); - active_call.connection_id = Some(connection_id); - - Ok((room, recipient_connection_ids)) - } - pub fn leave_room(&mut self, room_id: RoomId, connection_id: ConnectionId) -> Result { let connection = self .connections From c213c98ea40dca5408f1f4250bc338dc49953905 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 15:22:04 +0100 Subject: [PATCH 006/109] Remove `calls` table and use just `room_participants` --- crates/call/src/room.rs | 
7 +- .../20221109000000_test_schema.sql | 16 +- .../20221111092550_reconnection_support.sql | 15 +- crates/collab/src/db.rs | 165 +++------- crates/collab/src/rpc/store.rs | 307 +++++++++--------- crates/rpc/proto/zed.proto | 8 +- 6 files changed, 217 insertions(+), 301 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 3e55dc4ce96d2cd594929da1be5d4507ba183b42..4f3079e72c1e75ab1cbd5818eeabf3151a8c21a1 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -294,6 +294,11 @@ impl Room { .position(|participant| Some(participant.user_id) == self.client.user_id()); let local_participant = local_participant_ix.map(|ix| room.participants.swap_remove(ix)); + let pending_participant_user_ids = room + .pending_participants + .iter() + .map(|p| p.user_id) + .collect::>(); let remote_participant_user_ids = room .participants .iter() @@ -303,7 +308,7 @@ impl Room { self.user_store.update(cx, move |user_store, cx| { ( user_store.get_users(remote_participant_user_ids, cx), - user_store.get_users(room.pending_participant_user_ids, cx), + user_store.get_users(pending_participant_user_ids, cx), ) }); self.pending_room_update = Some(cx.spawn(|this, mut cx| async move { diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 93026575230fc3a6ea6ce1e865c77ef10fe2fc6f..5b38ebf8b1e9d88a12c18064f973aed183cf045b 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -70,16 +70,8 @@ CREATE TABLE "room_participants" ( "user_id" INTEGER NOT NULL REFERENCES users (id), "connection_id" INTEGER, "location_kind" INTEGER, - "location_project_id" INTEGER REFERENCES projects (id) + "location_project_id" INTEGER REFERENCES projects (id), + "initial_project_id" INTEGER REFERENCES projects (id), + "calling_user_id" INTEGER NOT NULL REFERENCES users (id) ); -CREATE UNIQUE INDEX 
"index_room_participants_on_user_id_and_room_id" ON "room_participants" ("user_id", "room_id"); - -CREATE TABLE "calls" ( - "id" INTEGER PRIMARY KEY, - "room_id" INTEGER NOT NULL REFERENCES rooms (id), - "calling_user_id" INTEGER NOT NULL REFERENCES users (id), - "called_user_id" INTEGER NOT NULL REFERENCES users (id), - "answering_connection_id" INTEGER, - "initial_project_id" INTEGER REFERENCES projects (id) -); -CREATE UNIQUE INDEX "index_calls_on_called_user_id" ON "calls" ("called_user_id"); +CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 8f932acff3ff19857298137adf52b9000f2b3d0f..621512bf43b3c2f39ce54a478f8e82825fc37cfd 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -32,16 +32,9 @@ CREATE TABLE IF NOT EXISTS "room_participants" ( "user_id" INTEGER NOT NULL REFERENCES users (id), "connection_id" INTEGER, "location_kind" INTEGER, - "location_project_id" INTEGER REFERENCES projects (id) + "location_project_id" INTEGER REFERENCES projects (id), + "initial_project_id" INTEGER REFERENCES projects (id), + "calling_user_id" INTEGER NOT NULL REFERENCES users (id) ); -CREATE UNIQUE INDEX "index_room_participants_on_user_id_and_room_id" ON "room_participants" ("user_id", "room_id"); +CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); -CREATE TABLE IF NOT EXISTS "calls" ( - "id" SERIAL PRIMARY KEY, - "room_id" INTEGER NOT NULL REFERENCES rooms (id), - "calling_user_id" INTEGER NOT NULL REFERENCES users (id), - "called_user_id" INTEGER NOT NULL REFERENCES users (id), - "answering_connection_id" INTEGER, - "initial_project_id" INTEGER REFERENCES projects (id) -); -CREATE UNIQUE INDEX "index_calls_on_called_user_id" ON "calls" 
("called_user_id"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 7cc0dc35fe8689d7d46f88e290855556e7f0574a..a98621d8942f54e0a92848d6c2bc8080ed998356 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -907,26 +907,14 @@ where sqlx::query( " - INSERT INTO room_participants (room_id, user_id, connection_id) - VALUES ($1, $2, $3) - ", - ) - .bind(room_id) - .bind(user_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - - sqlx::query( - " - INSERT INTO calls (room_id, calling_user_id, called_user_id, answering_connection_id) + INSERT INTO room_participants (room_id, user_id, connection_id, calling_user_id) VALUES ($1, $2, $3, $4) ", ) .bind(room_id) .bind(user_id) - .bind(user_id) .bind(connection_id.0 as i32) + .bind(user_id) .execute(&mut tx) .await?; @@ -945,31 +933,20 @@ where let mut tx = self.pool.begin().await?; sqlx::query( " - INSERT INTO calls (room_id, calling_user_id, called_user_id, initial_project_id) + INSERT INTO room_participants (room_id, user_id, calling_user_id, initial_project_id) VALUES ($1, $2, $3, $4) ", ) .bind(room_id) - .bind(calling_user_id) .bind(called_user_id) + .bind(calling_user_id) .bind(initial_project_id) .execute(&mut tx) .await?; - sqlx::query( - " - INSERT INTO room_participants (room_id, user_id) - VALUES ($1, $2) - ", - ) - .bind(room_id) - .bind(called_user_id) - .execute(&mut tx) - .await?; - let room = self.commit_room_transaction(room_id, tx).await?; - let incoming_call = - Self::build_incoming_call(&room, calling_user_id, initial_project_id); + let incoming_call = Self::build_incoming_call(&room, called_user_id) + .ok_or_else(|| anyhow!("failed to build incoming call"))?; Ok((room, incoming_call)) }) } @@ -980,24 +957,20 @@ where ) -> Result> { test_support!(self, { let mut tx = self.pool.begin().await?; - let call = sqlx::query_as::<_, Call>( + let room_id = sqlx::query_scalar::<_, RoomId>( " - SELECT * - FROM calls - WHERE called_user_id = $1 AND 
answering_connection_id IS NULL + SELECT room_id + FROM room_participants + WHERE user_id = $1 AND connection_id IS NULL ", ) .bind(user_id) .fetch_optional(&mut tx) .await?; - if let Some(call) = call { - let room = self.get_room(call.room_id, &mut tx).await?; - Ok(Some(Self::build_incoming_call( - &room, - call.calling_user_id, - call.initial_project_id, - ))) + if let Some(room_id) = room_id { + let room = self.get_room(room_id, &mut tx).await?; + Ok(Self::build_incoming_call(&room, user_id)) } else { Ok(None) } @@ -1006,26 +979,30 @@ where fn build_incoming_call( room: &proto::Room, - calling_user_id: UserId, - initial_project_id: Option, - ) -> proto::IncomingCall { - proto::IncomingCall { + called_user_id: UserId, + ) -> Option { + let pending_participant = room + .pending_participants + .iter() + .find(|participant| participant.user_id == called_user_id.to_proto())?; + + Some(proto::IncomingCall { room_id: room.id, - calling_user_id: calling_user_id.to_proto(), + calling_user_id: pending_participant.calling_user_id, participant_user_ids: room .participants .iter() .map(|participant| participant.user_id) .collect(), initial_project: room.participants.iter().find_map(|participant| { - let initial_project_id = initial_project_id?.to_proto(); + let initial_project_id = pending_participant.initial_project_id?; participant .projects .iter() .find(|project| project.id == initial_project_id) .cloned() }), - } + }) } pub async fn call_failed( @@ -1035,17 +1012,6 @@ where ) -> Result { test_support!(self, { let mut tx = self.pool.begin().await?; - sqlx::query( - " - DELETE FROM calls - WHERE room_id = $1 AND called_user_id = $2 - ", - ) - .bind(room_id) - .bind(called_user_id) - .execute(&mut tx) - .await?; - sqlx::query( " DELETE FROM room_participants @@ -1069,20 +1035,6 @@ where ) -> Result { test_support!(self, { let mut tx = self.pool.begin().await?; - sqlx::query( - " - UPDATE calls - SET answering_connection_id = $1 - WHERE room_id = $2 AND called_user_id = $3 
- RETURNING 1 - ", - ) - .bind(connection_id.0 as i32) - .bind(room_id) - .bind(user_id) - .fetch_one(&mut tx) - .await?; - sqlx::query( " UPDATE room_participants @@ -1096,54 +1048,8 @@ where .bind(user_id) .fetch_one(&mut tx) .await?; - self.commit_room_transaction(room_id, tx).await }) - - // let connection = self - // .connections - // .get_mut(&connection_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // let user_id = connection.user_id; - // let recipient_connection_ids = self.connection_ids_for_user(user_id).collect::>(); - - // let connected_user = self - // .connected_users - // .get_mut(&user_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // let active_call = connected_user - // .active_call - // .as_mut() - // .ok_or_else(|| anyhow!("not being called"))?; - // anyhow::ensure!( - // active_call.room_id == room_id && active_call.connection_id.is_none(), - // "not being called on this room" - // ); - - // let room = self - // .rooms - // .get_mut(&room_id) - // .ok_or_else(|| anyhow!("no such room"))?; - // anyhow::ensure!( - // room.pending_participant_user_ids - // .contains(&user_id.to_proto()), - // anyhow!("no such room") - // ); - // room.pending_participant_user_ids - // .retain(|pending| *pending != user_id.to_proto()); - // room.participants.push(proto::Participant { - // user_id: user_id.to_proto(), - // peer_id: connection_id.0, - // projects: Default::default(), - // location: Some(proto::ParticipantLocation { - // variant: Some(proto::participant_location::Variant::External( - // proto::participant_location::External {}, - // )), - // }), - // }); - // active_call.connection_id = Some(connection_id); - - // Ok((room, recipient_connection_ids)) } pub async fn update_room_participant_location( @@ -1231,9 +1137,9 @@ where .await?; let mut db_participants = - sqlx::query_as::<_, (UserId, Option, Option, Option)>( + sqlx::query_as::<_, (UserId, Option, Option, Option, UserId, Option)>( " - SELECT user_id, connection_id, 
location_kind, location_project_id + SELECT user_id, connection_id, location_kind, location_project_id, calling_user_id, initial_project_id FROM room_participants WHERE room_id = $1 ", @@ -1242,9 +1148,16 @@ where .fetch(&mut *tx); let mut participants = Vec::new(); - let mut pending_participant_user_ids = Vec::new(); + let mut pending_participants = Vec::new(); while let Some(participant) = db_participants.next().await { - let (user_id, connection_id, _location_kind, _location_project_id) = participant?; + let ( + user_id, + connection_id, + _location_kind, + _location_project_id, + calling_user_id, + initial_project_id, + ) = participant?; if let Some(connection_id) = connection_id { participants.push(proto::Participant { user_id: user_id.to_proto(), @@ -1257,7 +1170,11 @@ where }), }); } else { - pending_participant_user_ids.push(user_id.to_proto()); + pending_participants.push(proto::PendingParticipant { + user_id: user_id.to_proto(), + calling_user_id: calling_user_id.to_proto(), + initial_project_id: initial_project_id.map(|id| id.to_proto()), + }); } } drop(db_participants); @@ -1296,7 +1213,7 @@ where version: room.version as u64, live_kit_room: room.live_kit_room, participants, - pending_participant_user_ids, + pending_participants, }) } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index dfd534dbe9502414e273f9924f00764945e515a1..610a653dc9841fdca10f672789aafb3735558f99 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -258,80 +258,81 @@ impl Store { } pub fn leave_room(&mut self, room_id: RoomId, connection_id: ConnectionId) -> Result { - let connection = self - .connections - .get_mut(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let user_id = connection.user_id; - - let connected_user = self - .connected_users - .get(&user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - anyhow::ensure!( - connected_user - .active_call - .map_or(false, |call| call.room_id == 
room_id - && call.connection_id == Some(connection_id)), - "cannot leave a room before joining it" - ); - - // Given that users can only join one room at a time, we can safely unshare - // and leave all projects associated with the connection. - let mut unshared_projects = Vec::new(); - let mut left_projects = Vec::new(); - for project_id in connection.projects.clone() { - if let Ok((_, project)) = self.unshare_project(project_id, connection_id) { - unshared_projects.push(project); - } else if let Ok(project) = self.leave_project(project_id, connection_id) { - left_projects.push(project); - } - } - self.connected_users.get_mut(&user_id).unwrap().active_call = None; - - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.participants - .retain(|participant| participant.peer_id != connection_id.0); - - let mut canceled_call_connection_ids = Vec::new(); - room.pending_participant_user_ids - .retain(|pending_participant_user_id| { - if let Some(connected_user) = self - .connected_users - .get_mut(&UserId::from_proto(*pending_participant_user_id)) - { - if let Some(call) = connected_user.active_call.as_ref() { - if call.calling_user_id == user_id { - connected_user.active_call.take(); - canceled_call_connection_ids - .extend(connected_user.connection_ids.iter().copied()); - false - } else { - true - } - } else { - true - } - } else { - true - } - }); - - let room = if room.participants.is_empty() { - Cow::Owned(self.rooms.remove(&room_id).unwrap()) - } else { - Cow::Borrowed(self.rooms.get(&room_id).unwrap()) - }; - - Ok(LeftRoom { - room, - unshared_projects, - left_projects, - canceled_call_connection_ids, - }) + todo!() + // let connection = self + // .connections + // .get_mut(&connection_id) + // .ok_or_else(|| anyhow!("no such connection"))?; + // let user_id = connection.user_id; + + // let connected_user = self + // .connected_users + // .get(&user_id) + // .ok_or_else(|| anyhow!("no such connection"))?; + // 
anyhow::ensure!( + // connected_user + // .active_call + // .map_or(false, |call| call.room_id == room_id + // && call.connection_id == Some(connection_id)), + // "cannot leave a room before joining it" + // ); + + // // Given that users can only join one room at a time, we can safely unshare + // // and leave all projects associated with the connection. + // let mut unshared_projects = Vec::new(); + // let mut left_projects = Vec::new(); + // for project_id in connection.projects.clone() { + // if let Ok((_, project)) = self.unshare_project(project_id, connection_id) { + // unshared_projects.push(project); + // } else if let Ok(project) = self.leave_project(project_id, connection_id) { + // left_projects.push(project); + // } + // } + // self.connected_users.get_mut(&user_id).unwrap().active_call = None; + + // let room = self + // .rooms + // .get_mut(&room_id) + // .ok_or_else(|| anyhow!("no such room"))?; + // room.participants + // .retain(|participant| participant.peer_id != connection_id.0); + + // let mut canceled_call_connection_ids = Vec::new(); + // room.pending_participant_user_ids + // .retain(|pending_participant_user_id| { + // if let Some(connected_user) = self + // .connected_users + // .get_mut(&UserId::from_proto(*pending_participant_user_id)) + // { + // if let Some(call) = connected_user.active_call.as_ref() { + // if call.calling_user_id == user_id { + // connected_user.active_call.take(); + // canceled_call_connection_ids + // .extend(connected_user.connection_ids.iter().copied()); + // false + // } else { + // true + // } + // } else { + // true + // } + // } else { + // true + // } + // }); + + // let room = if room.participants.is_empty() { + // Cow::Owned(self.rooms.remove(&room_id).unwrap()) + // } else { + // Cow::Borrowed(self.rooms.get(&room_id).unwrap()) + // }; + + // Ok(LeftRoom { + // room, + // unshared_projects, + // left_projects, + // canceled_call_connection_ids, + // }) } pub fn rooms(&self) -> &BTreeMap { @@ -344,48 +345,49 
@@ impl Store { called_user_id: UserId, canceller_connection_id: ConnectionId, ) -> Result<(&proto::Room, HashSet)> { - let canceller_user_id = self.user_id_for_connection(canceller_connection_id)?; - let canceller = self - .connected_users - .get(&canceller_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let recipient = self - .connected_users - .get(&called_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let canceller_active_call = canceller - .active_call - .as_ref() - .ok_or_else(|| anyhow!("no active call"))?; - let recipient_active_call = recipient - .active_call - .as_ref() - .ok_or_else(|| anyhow!("no active call for recipient"))?; - - anyhow::ensure!( - canceller_active_call.room_id == room_id, - "users are on different calls" - ); - anyhow::ensure!( - recipient_active_call.room_id == room_id, - "users are on different calls" - ); - anyhow::ensure!( - recipient_active_call.connection_id.is_none(), - "recipient has already answered" - ); - let room_id = recipient_active_call.room_id; - let room = self - .rooms - .get_mut(&room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); - - let recipient = self.connected_users.get_mut(&called_user_id).unwrap(); - recipient.active_call.take(); - - Ok((room, recipient.connection_ids.clone())) + todo!() + // let canceller_user_id = self.user_id_for_connection(canceller_connection_id)?; + // let canceller = self + // .connected_users + // .get(&canceller_user_id) + // .ok_or_else(|| anyhow!("no such connection"))?; + // let recipient = self + // .connected_users + // .get(&called_user_id) + // .ok_or_else(|| anyhow!("no such connection"))?; + // let canceller_active_call = canceller + // .active_call + // .as_ref() + // .ok_or_else(|| anyhow!("no active call"))?; + // let recipient_active_call = recipient + // .active_call + // .as_ref() + // .ok_or_else(|| anyhow!("no active call for 
recipient"))?; + + // anyhow::ensure!( + // canceller_active_call.room_id == room_id, + // "users are on different calls" + // ); + // anyhow::ensure!( + // recipient_active_call.room_id == room_id, + // "users are on different calls" + // ); + // anyhow::ensure!( + // recipient_active_call.connection_id.is_none(), + // "recipient has already answered" + // ); + // let room_id = recipient_active_call.room_id; + // let room = self + // .rooms + // .get_mut(&room_id) + // .ok_or_else(|| anyhow!("no such room"))?; + // room.pending_participant_user_ids + // .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); + + // let recipient = self.connected_users.get_mut(&called_user_id).unwrap(); + // recipient.active_call.take(); + + // Ok((room, recipient.connection_ids.clone())) } pub fn decline_call( @@ -393,31 +395,32 @@ impl Store { room_id: RoomId, recipient_connection_id: ConnectionId, ) -> Result<(&proto::Room, Vec)> { - let called_user_id = self.user_id_for_connection(recipient_connection_id)?; - let recipient = self - .connected_users - .get_mut(&called_user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - if let Some(active_call) = recipient.active_call { - anyhow::ensure!(active_call.room_id == room_id, "no such room"); - anyhow::ensure!( - active_call.connection_id.is_none(), - "cannot decline a call after joining room" - ); - recipient.active_call.take(); - let recipient_connection_ids = self - .connection_ids_for_user(called_user_id) - .collect::>(); - let room = self - .rooms - .get_mut(&active_call.room_id) - .ok_or_else(|| anyhow!("no such room"))?; - room.pending_participant_user_ids - .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); - Ok((room, recipient_connection_ids)) - } else { - Err(anyhow!("user is not being called")) - } + todo!() + // let called_user_id = self.user_id_for_connection(recipient_connection_id)?; + // let recipient = self + // .connected_users + // .get_mut(&called_user_id) + // .ok_or_else(|| 
anyhow!("no such connection"))?; + // if let Some(active_call) = recipient.active_call { + // anyhow::ensure!(active_call.room_id == room_id, "no such room"); + // anyhow::ensure!( + // active_call.connection_id.is_none(), + // "cannot decline a call after joining room" + // ); + // recipient.active_call.take(); + // let recipient_connection_ids = self + // .connection_ids_for_user(called_user_id) + // .collect::>(); + // let room = self + // .rooms + // .get_mut(&active_call.room_id) + // .ok_or_else(|| anyhow!("no such room"))?; + // room.pending_participant_user_ids + // .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); + // Ok((room, recipient_connection_ids)) + // } else { + // Err(anyhow!("user is not being called")) + // } } pub fn unshare_project( @@ -767,13 +770,13 @@ impl Store { } for (room_id, room) in &self.rooms { - for pending_user_id in &room.pending_participant_user_ids { - assert!( - self.connected_users - .contains_key(&UserId::from_proto(*pending_user_id)), - "call is active on a user that has disconnected" - ); - } + // for pending_user_id in &room.pending_participant_user_ids { + // assert!( + // self.connected_users + // .contains_key(&UserId::from_proto(*pending_user_id)), + // "call is active on a user that has disconnected" + // ); + // } for participant in &room.participants { assert!( @@ -793,10 +796,10 @@ impl Store { } } - assert!( - !room.pending_participant_user_ids.is_empty() || !room.participants.is_empty(), - "room can't be empty" - ); + // assert!( + // !room.pending_participant_user_ids.is_empty() || !room.participants.is_empty(), + // "room can't be empty" + // ); } for (project_id, project) in &self.projects { diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 07e6fae3a81e6f2aaee0ec2678d553787c75d447..c1daf758230058b53c7929e8a5cf3859b4d9f81b 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -166,7 +166,7 @@ message Room { uint64 id = 1; uint64 version = 2; 
repeated Participant participants = 3; - repeated uint64 pending_participant_user_ids = 4; + repeated PendingParticipant pending_participants = 4; string live_kit_room = 5; } @@ -177,6 +177,12 @@ message Participant { ParticipantLocation location = 4; } +message PendingParticipant { + uint64 user_id = 1; + uint64 calling_user_id = 2; + optional uint64 initial_project_id = 3; +} + message ParticipantProject { uint64 id = 1; repeated string worktree_root_names = 2; From 0d1d267213b7494730dd9ae6abbdbb00e2bed34d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 15:41:56 +0100 Subject: [PATCH 007/109] Move `Store::decline_call` to `Db::decline_call` --- crates/collab/src/db.rs | 18 ++++++++++++++++ crates/collab/src/rpc.rs | 28 +++++++++++++++---------- crates/collab/src/rpc/store.rs | 38 +++------------------------------- 3 files changed, 38 insertions(+), 46 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index a98621d8942f54e0a92848d6c2bc8080ed998356..10f1dd04424ba7ed1c0fa07e685c285f8769b1eb 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1027,6 +1027,24 @@ where }) } + pub async fn decline_call(&self, room_id: RoomId, user_id: UserId) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + sqlx::query( + " + DELETE FROM room_participants + WHERE room_id = $1 AND user_id = $2 AND connection_id IS NULL + ", + ) + .bind(room_id) + .bind(user_id) + .execute(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await + }) + } + pub async fn join_room( &self, room_id: RoomId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c7c222ee1c30511f6bf296e62b502d5025292b2f..652ac5917b0e80b7745b8fd17673db8e4dcc1753 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -800,19 +800,25 @@ impl Server { } async fn decline_call(self: Arc, message: Message) -> Result<()> { - let recipient_user_id = message.sender_user_id; + let room = self + 
.app_state + .db + .decline_call( + RoomId::from_proto(message.payload.room_id), + message.sender_user_id, + ) + .await?; + for recipient_id in self + .store() + .await + .connection_ids_for_user(message.sender_user_id) { - let mut store = self.store().await; - let (room, recipient_connection_ids) = - store.decline_call(message.payload.room_id, message.sender_connection_id)?; - for recipient_id in recipient_connection_ids { - self.peer - .send(recipient_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(room); + self.peer + .send(recipient_id, proto::CallCanceled {}) + .trace_err(); } - self.update_user_contacts(recipient_user_id).await?; + self.room_updated(&room); + self.update_user_contacts(message.sender_user_id).await?; Ok(()) } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 610a653dc9841fdca10f672789aafb3735558f99..d64464f601be06f89ac583e3edd26e08716ac806 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -162,8 +162,9 @@ impl Store { result.room = Some(Cow::Owned(left_room.room.into_owned())); result.canceled_call_connection_ids = left_room.canceled_call_connection_ids; } else if connected_user.connection_ids.len() == 1 { - let (room, _) = self.decline_call(room_id, connection_id)?; - result.room = Some(Cow::Owned(room.clone())); + todo!() + // let (room, _) = self.decline_call(room_id, connection_id)?; + // result.room = Some(Cow::Owned(room.clone())); } } @@ -390,39 +391,6 @@ impl Store { // Ok((room, recipient.connection_ids.clone())) } - pub fn decline_call( - &mut self, - room_id: RoomId, - recipient_connection_id: ConnectionId, - ) -> Result<(&proto::Room, Vec)> { - todo!() - // let called_user_id = self.user_id_for_connection(recipient_connection_id)?; - // let recipient = self - // .connected_users - // .get_mut(&called_user_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // if let Some(active_call) = recipient.active_call { - // 
anyhow::ensure!(active_call.room_id == room_id, "no such room"); - // anyhow::ensure!( - // active_call.connection_id.is_none(), - // "cannot decline a call after joining room" - // ); - // recipient.active_call.take(); - // let recipient_connection_ids = self - // .connection_ids_for_user(called_user_id) - // .collect::>(); - // let room = self - // .rooms - // .get_mut(&active_call.room_id) - // .ok_or_else(|| anyhow!("no such room"))?; - // room.pending_participant_user_ids - // .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); - // Ok((room, recipient_connection_ids)) - // } else { - // Err(anyhow!("user is not being called")) - // } - } - pub fn unshare_project( &mut self, project_id: ProjectId, From 1135aeecb8b9640111bc1e0c5566c8b3b64b7e4e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 16:59:54 +0100 Subject: [PATCH 008/109] WIP: Move `Store::leave_room` to `Db::leave_room` --- .../20221109000000_test_schema.sql | 4 +- .../20221111092550_reconnection_support.sql | 4 +- crates/collab/src/db.rs | 112 ++++++++++++++++++ crates/collab/src/rpc.rs | 71 ++++++----- crates/collab/src/rpc/store.rs | 96 +-------------- 5 files changed, 161 insertions(+), 126 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 5b38ebf8b1e9d88a12c18064f973aed183cf045b..44495f16ce368dd877df0cc6c04eef95fba04fa1 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -48,7 +48,7 @@ CREATE TABLE "projects" ( CREATE TABLE "project_collaborators" ( "id" INTEGER PRIMARY KEY, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL, @@ -58,7 +58,7 @@ CREATE INDEX 
"index_project_collaborators_on_project_id" ON "project_collaborato CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "root_name" VARCHAR NOT NULL, PRIMARY KEY(project_id, id) ); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 621512bf43b3c2f39ce54a478f8e82825fc37cfd..ed6da2b7b14a3f31fdfffba4a80af286002f4739 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -10,7 +10,7 @@ ALTER TABLE "projects" CREATE TABLE "project_collaborators" ( "id" SERIAL PRIMARY KEY, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL, @@ -20,7 +20,7 @@ CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborato CREATE TABLE IF NOT EXISTS "worktrees" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "root_name" VARCHAR NOT NULL, PRIMARY KEY(project_id, id) ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 10f1dd04424ba7ed1c0fa07e685c285f8769b1eb..fc5e3c242b9ae05f2cbffe2615fc2edcaac77b8e 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1070,6 +1070,97 @@ where }) } + pub async fn leave_room( + &self, + room_id: RoomId, + connection_id: ConnectionId, + ) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + + // Leave room. 
+ let user_id: UserId = sqlx::query_scalar( + " + DELETE FROM room_participants + WHERE room_id = $1 AND connection_id = $2 + RETURNING user_id + ", + ) + .bind(room_id) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + + // Cancel pending calls initiated by the leaving user. + let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( + " + DELETE FROM room_participants + WHERE calling_user_id = $1 AND connection_id IS NULL + RETURNING user_id + ", + ) + .bind(room_id) + .bind(connection_id.0 as i32) + .fetch_all(&mut tx) + .await?; + + let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>( + " + SELECT project_collaborators.* + FROM projects, project_collaborators + WHERE + projects.room_id = $1 AND + projects.user_id = $2 AND + projects.id = project_collaborators.project_id + ", + ) + .bind(room_id) + .bind(user_id) + .fetch(&mut tx); + + let mut left_projects = HashMap::default(); + while let Some(collaborator) = project_collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + }); + + let collaborator_connection_id = ConnectionId(collaborator.connection_id as u32); + if collaborator_connection_id != connection_id || collaborator.is_host { + left_project.connection_ids.push(collaborator_connection_id); + } + + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + } + } + drop(project_collaborators); + + sqlx::query( + " + DELETE FROM projects + WHERE room_id = $1 AND user_id = $2 + ", + ) + .bind(room_id) + .bind(user_id) + .execute(&mut tx) + .await?; + + let room = self.commit_room_transaction(room_id, tx).await?; + Ok(LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + }) + }) + } + pub async fn update_room_participant_location( &self, room_id: RoomId, @@ -1667,6 +1758,27 @@ 
pub struct Project { pub unregistered: bool, } +#[derive(Clone, Debug, Default, FromRow, PartialEq)] +pub struct ProjectCollaborator { + pub project_id: ProjectId, + pub connection_id: i32, + pub user_id: UserId, + pub replica_id: i32, + pub is_host: bool, +} + +pub struct LeftProject { + pub id: ProjectId, + pub host_user_id: UserId, + pub connection_ids: Vec, +} + +pub struct LeftRoom { + pub room: proto::Room, + pub left_projects: HashMap, + pub canceled_calls_to_user_ids: Vec, +} + #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { Accepted { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 652ac5917b0e80b7745b8fd17673db8e4dcc1753..1221964601592223962f54e9ed5be38d24c7faaa 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -658,14 +658,20 @@ impl Server { async fn leave_room(self: Arc, message: Message) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let room_left; - { - let mut store = self.store().await; - let left_room = store.leave_room(message.payload.id, message.sender_connection_id)?; - contacts_to_update.insert(message.sender_user_id); - for project in left_room.unshared_projects { - for connection_id in project.connection_ids() { + let left_room = self + .app_state + .db + .leave_room( + RoomId::from_proto(message.payload.id), + message.sender_connection_id, + ) + .await?; + contacts_to_update.insert(message.sender_user_id); + + for project in left_room.left_projects.into_values() { + if project.host_user_id == message.sender_user_id { + for connection_id in project.connection_ids { self.peer.send( connection_id, proto::UnshareProject { @@ -673,41 +679,42 @@ impl Server { }, )?; } - } - - for project in left_room.left_projects { - if project.remove_collaborator { - for connection_id in project.connection_ids { - self.peer.send( - connection_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: message.sender_connection_id.0, - }, - )?; - } - + } else 
{ + for connection_id in project.connection_ids { self.peer.send( - message.sender_connection_id, - proto::UnshareProject { + connection_id, + proto::RemoveProjectCollaborator { project_id: project.id.to_proto(), + peer_id: message.sender_connection_id.0, }, )?; } - } - self.room_updated(&left_room.room); - room_left = self.room_left(&left_room.room, message.sender_connection_id); + self.peer.send( + message.sender_connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + )?; + } + } - for connection_id in left_room.canceled_call_connection_ids { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - contacts_to_update.extend(store.user_id_for_connection(connection_id).ok()); + self.room_updated(&left_room.room); + { + let store = self.store().await; + for user_id in left_room.canceled_calls_to_user_ids { + for connection_id in store.connection_ids_for_user(user_id) { + self.peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + contacts_to_update.insert(user_id); } } - room_left.await.trace_err(); + self.room_left(&left_room.room, message.sender_connection_id) + .await + .trace_err(); for user_id in contacts_to_update { self.update_user_contacts(user_id).await?; } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index d64464f601be06f89ac583e3edd26e08716ac806..4ea2c7b38ef9b78c9ca02ad4653f1707a0781387 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -90,13 +90,6 @@ pub struct LeftProject { pub remove_collaborator: bool, } -pub struct LeftRoom<'a> { - pub room: Cow<'a, proto::Room>, - pub unshared_projects: Vec, - pub left_projects: Vec, - pub canceled_call_connection_ids: Vec, -} - #[derive(Copy, Clone)] pub struct Metrics { pub connections: usize, @@ -156,11 +149,12 @@ impl Store { if let Some(active_call) = connected_user.active_call.as_ref() { let room_id = active_call.room_id; if active_call.connection_id == Some(connection_id) { - 
let left_room = self.leave_room(room_id, connection_id)?; - result.hosted_projects = left_room.unshared_projects; - result.guest_projects = left_room.left_projects; - result.room = Some(Cow::Owned(left_room.room.into_owned())); - result.canceled_call_connection_ids = left_room.canceled_call_connection_ids; + todo!() + // let left_room = self.leave_room(room_id, connection_id)?; + // result.hosted_projects = left_room.unshared_projects; + // result.guest_projects = left_room.left_projects; + // result.room = Some(Cow::Owned(left_room.room.into_owned())); + // result.canceled_call_connection_ids = left_room.canceled_call_connection_ids; } else if connected_user.connection_ids.len() == 1 { todo!() // let (room, _) = self.decline_call(room_id, connection_id)?; @@ -258,84 +252,6 @@ impl Store { } } - pub fn leave_room(&mut self, room_id: RoomId, connection_id: ConnectionId) -> Result { - todo!() - // let connection = self - // .connections - // .get_mut(&connection_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // let user_id = connection.user_id; - - // let connected_user = self - // .connected_users - // .get(&user_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // anyhow::ensure!( - // connected_user - // .active_call - // .map_or(false, |call| call.room_id == room_id - // && call.connection_id == Some(connection_id)), - // "cannot leave a room before joining it" - // ); - - // // Given that users can only join one room at a time, we can safely unshare - // // and leave all projects associated with the connection. 
- // let mut unshared_projects = Vec::new(); - // let mut left_projects = Vec::new(); - // for project_id in connection.projects.clone() { - // if let Ok((_, project)) = self.unshare_project(project_id, connection_id) { - // unshared_projects.push(project); - // } else if let Ok(project) = self.leave_project(project_id, connection_id) { - // left_projects.push(project); - // } - // } - // self.connected_users.get_mut(&user_id).unwrap().active_call = None; - - // let room = self - // .rooms - // .get_mut(&room_id) - // .ok_or_else(|| anyhow!("no such room"))?; - // room.participants - // .retain(|participant| participant.peer_id != connection_id.0); - - // let mut canceled_call_connection_ids = Vec::new(); - // room.pending_participant_user_ids - // .retain(|pending_participant_user_id| { - // if let Some(connected_user) = self - // .connected_users - // .get_mut(&UserId::from_proto(*pending_participant_user_id)) - // { - // if let Some(call) = connected_user.active_call.as_ref() { - // if call.calling_user_id == user_id { - // connected_user.active_call.take(); - // canceled_call_connection_ids - // .extend(connected_user.connection_ids.iter().copied()); - // false - // } else { - // true - // } - // } else { - // true - // } - // } else { - // true - // } - // }); - - // let room = if room.participants.is_empty() { - // Cow::Owned(self.rooms.remove(&room_id).unwrap()) - // } else { - // Cow::Borrowed(self.rooms.get(&room_id).unwrap()) - // }; - - // Ok(LeftRoom { - // room, - // unshared_projects, - // left_projects, - // canceled_call_connection_ids, - // }) - } - pub fn rooms(&self) -> &BTreeMap { &self.rooms } From 9f39dcf7cf1dc589efe93b4815976ffc95118cb1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 18:53:23 +0100 Subject: [PATCH 009/109] Get basic calls test passing again --- crates/collab/src/db.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 
fc5e3c242b9ae05f2cbffe2615fc2edcaac77b8e..e092bd950149e1b9cbfec6ce3a6e676442cee8f0 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1110,7 +1110,7 @@ where FROM projects, project_collaborators WHERE projects.room_id = $1 AND - projects.user_id = $2 AND + projects.host_user_id = $2 AND projects.id = project_collaborators.project_id ", ) @@ -1144,7 +1144,7 @@ where sqlx::query( " DELETE FROM projects - WHERE room_id = $1 AND user_id = $2 + WHERE room_id = $1 AND host_user_id = $2 ", ) .bind(room_id) From 11caba4a4c8b536fb6c0d3d0eea3f08c57cfce67 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 18:54:08 +0100 Subject: [PATCH 010/109] Remove stray log statement --- crates/collab/src/integration_tests.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 6d3cff1718e983812539374373b540ef5ba6f27f..3a4c2368e8060994482b464dfbe7081a080e3efc 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -71,8 +71,6 @@ async fn test_basic_calls( deterministic.forbid_parking(); let mut server = TestServer::start(cx_a.background()).await; - let start = std::time::Instant::now(); - let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; let client_c = server.create_client(cx_c, "user_c").await; @@ -258,8 +256,6 @@ async fn test_basic_calls( pending: Default::default() } ); - - eprintln!("finished test {:?}", start.elapsed()); } #[gpui::test(iterations = 10)] From 2145965749b0edff3972fff1124d31cf3ff55348 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 11 Nov 2022 19:36:20 +0100 Subject: [PATCH 011/109] WIP --- .../20221109000000_test_schema.sql | 3 +- .../20221111092550_reconnection_support.sql | 4 +- crates/collab/src/db.rs | 13 +----- crates/collab/src/rpc.rs | 8 +--- crates/collab/src/rpc/store.rs | 41 +------------------ 5 files changed, 
10 insertions(+), 59 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 44495f16ce368dd877df0cc6c04eef95fba04fa1..477cc5d6075458c666f410013ae53d5b124f8767 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -68,10 +68,11 @@ CREATE TABLE "room_participants" ( "id" INTEGER PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), "user_id" INTEGER NOT NULL REFERENCES users (id), - "connection_id" INTEGER, + "answering_connection_id" INTEGER, "location_kind" INTEGER, "location_project_id" INTEGER REFERENCES projects (id), "initial_project_id" INTEGER REFERENCES projects (id), "calling_user_id" INTEGER NOT NULL REFERENCES users (id) + "calling_connection_id" INTEGER NOT NULL ); CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index ed6da2b7b14a3f31fdfffba4a80af286002f4739..48e6b50b06c0307170fd673542b3857d92c8b7f8 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -34,7 +34,7 @@ CREATE TABLE IF NOT EXISTS "room_participants" ( "location_kind" INTEGER, "location_project_id" INTEGER REFERENCES projects (id), "initial_project_id" INTEGER REFERENCES projects (id), - "calling_user_id" INTEGER NOT NULL REFERENCES users (id) + "calling_user_id" INTEGER NOT NULL REFERENCES users (id), + "calling_connection_id" INTEGER NOT NULL ); CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); - diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e092bd950149e1b9cbfec6ce3a6e676442cee8f0..3ffdc602dad86effa739b791257cfc02e7416c75 100644 --- 
a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1342,7 +1342,7 @@ where INSERT INTO projects (host_user_id, room_id) VALUES ($1) RETURNING id - ", + ", ) .bind(user_id) .bind(room_id) @@ -1354,7 +1354,7 @@ where sqlx::query( " INSERT INTO worktrees (id, project_id, root_name) - ", + ", ) .bind(worktree.id as i32) .bind(project_id) @@ -1741,15 +1741,6 @@ pub struct Room { pub live_kit_room: String, } -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -pub struct Call { - pub room_id: RoomId, - pub calling_user_id: UserId, - pub called_user_id: UserId, - pub answering_connection_id: Option, - pub initial_project_id: Option, -} - id_type!(ProjectId); #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] pub struct Project { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 1221964601592223962f54e9ed5be38d24c7faaa..5b713226b1859e66473409c3ed4d43a6b84dcc24 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -431,12 +431,8 @@ impl Server { let mut contacts_to_update = HashSet::default(); let mut room_left = None; { - let mut store = self.store().await; - - #[cfg(test)] - let removed_connection = store.remove_connection(connection_id).unwrap(); - #[cfg(not(test))] - let removed_connection = store.remove_connection(connection_id)?; + let removed_connection = self.store().await.remove_connection(connection_id)?; + self.app_state.db.remove_connection(connection_id); for project in removed_connection.hosted_projects { projects_to_unshare.push(project.id); diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 4ea2c7b38ef9b78c9ca02ad4653f1707a0781387..de444924091d3c6ce7013b872121f85a3fc03bb4 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -72,16 +72,6 @@ pub struct Worktree { pub type ReplicaId = u16; -#[derive(Default)] -pub struct RemovedConnectionState<'a> { - pub user_id: UserId, - pub hosted_projects: Vec, - pub guest_projects: Vec, - pub 
contact_ids: HashSet, - pub room: Option>, - pub canceled_call_connection_ids: Vec, -} - pub struct LeftProject { pub id: ProjectId, pub host_user_id: UserId, @@ -129,47 +119,20 @@ impl Store { } #[instrument(skip(self))] - pub fn remove_connection( - &mut self, - connection_id: ConnectionId, - ) -> Result { + pub fn remove_connection(&mut self, connection_id: ConnectionId) -> Result<()> { let connection = self .connections .get_mut(&connection_id) .ok_or_else(|| anyhow!("no such connection"))?; let user_id = connection.user_id; - - let mut result = RemovedConnectionState { - user_id, - ..Default::default() - }; - - let connected_user = self.connected_users.get(&user_id).unwrap(); - if let Some(active_call) = connected_user.active_call.as_ref() { - let room_id = active_call.room_id; - if active_call.connection_id == Some(connection_id) { - todo!() - // let left_room = self.leave_room(room_id, connection_id)?; - // result.hosted_projects = left_room.unshared_projects; - // result.guest_projects = left_room.left_projects; - // result.room = Some(Cow::Owned(left_room.room.into_owned())); - // result.canceled_call_connection_ids = left_room.canceled_call_connection_ids; - } else if connected_user.connection_ids.len() == 1 { - todo!() - // let (room, _) = self.decline_call(room_id, connection_id)?; - // result.room = Some(Cow::Owned(room.clone())); - } - } - let connected_user = self.connected_users.get_mut(&user_id).unwrap(); connected_user.connection_ids.remove(&connection_id); if connected_user.connection_ids.is_empty() { self.connected_users.remove(&user_id); } self.connections.remove(&connection_id).unwrap(); - - Ok(result) + Ok(()) } pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result { From 9902211af18da0979055de6d1c611e58973deed9 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 10:13:36 +0100 Subject: [PATCH 012/109] Leave room when connection is dropped --- crates/call/src/room.rs | 4 +- 
.../20221109000000_test_schema.sql | 5 +- .../20221111092550_reconnection_support.sql | 3 +- crates/collab/src/db.rs | 189 ++++++++-------- crates/collab/src/rpc.rs | 202 +++++++----------- crates/collab/src/rpc/store.rs | 10 +- crates/rpc/proto/zed.proto | 4 +- 7 files changed, 183 insertions(+), 234 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 4f3079e72c1e75ab1cbd5818eeabf3151a8c21a1..0ecd6082d63f576be7c4f3342aa679921600e873 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -53,7 +53,7 @@ impl Entity for Room { fn release(&mut self, _: &mut MutableAppContext) { if self.status.is_online() { - self.client.send(proto::LeaveRoom { id: self.id }).log_err(); + self.client.send(proto::LeaveRoom {}).log_err(); } } } @@ -241,7 +241,7 @@ impl Room { self.participant_user_ids.clear(); self.subscriptions.clear(); self.live_kit.take(); - self.client.send(proto::LeaveRoom { id: self.id })?; + self.client.send(proto::LeaveRoom {})?; Ok(()) } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 477cc5d6075458c666f410013ae53d5b124f8767..2cef514e5a1810026d62047cef6b61c817e33155 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -43,7 +43,8 @@ CREATE TABLE "rooms" ( CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY, "room_id" INTEGER REFERENCES rooms (id), - "host_user_id" INTEGER REFERENCES users (id) NOT NULL + "host_user_id" INTEGER REFERENCES users (id) NOT NULL, + "host_connection_id" INTEGER NOT NULL ); CREATE TABLE "project_collaborators" ( @@ -72,7 +73,7 @@ CREATE TABLE "room_participants" ( "location_kind" INTEGER, "location_project_id" INTEGER REFERENCES projects (id), "initial_project_id" INTEGER REFERENCES projects (id), - "calling_user_id" INTEGER NOT NULL REFERENCES users (id) + "calling_user_id" INTEGER NOT NULL REFERENCES 
users (id), "calling_connection_id" INTEGER NOT NULL ); CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 48e6b50b06c0307170fd673542b3857d92c8b7f8..7b82ce9ce7f49ec953a2c8ef54e2cdbfe07d3274 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -6,6 +6,7 @@ CREATE TABLE IF NOT EXISTS "rooms" ( ALTER TABLE "projects" ADD "room_id" INTEGER REFERENCES rooms (id), + ADD "host_connection_id" INTEGER, DROP COLUMN "unregistered"; CREATE TABLE "project_collaborators" ( @@ -30,7 +31,7 @@ CREATE TABLE IF NOT EXISTS "room_participants" ( "id" SERIAL PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), "user_id" INTEGER NOT NULL REFERENCES users (id), - "connection_id" INTEGER, + "answering_connection_id" INTEGER, "location_kind" INTEGER, "location_project_id" INTEGER REFERENCES projects (id), "initial_project_id" INTEGER REFERENCES projects (id), diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 3ffdc602dad86effa739b791257cfc02e7416c75..f32bdf96eff4725f93d2154f7f9c85336fb21340 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -907,14 +907,15 @@ where sqlx::query( " - INSERT INTO room_participants (room_id, user_id, connection_id, calling_user_id) - VALUES ($1, $2, $3, $4) + INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id) + VALUES ($1, $2, $3, $4, $5) ", ) .bind(room_id) .bind(user_id) .bind(connection_id.0 as i32) .bind(user_id) + .bind(connection_id.0 as i32) .execute(&mut tx) .await?; @@ -926,6 +927,7 @@ where &self, room_id: RoomId, calling_user_id: UserId, + calling_connection_id: ConnectionId, called_user_id: UserId, initial_project_id: Option, ) -> Result<(proto::Room, 
proto::IncomingCall)> { @@ -933,13 +935,14 @@ where let mut tx = self.pool.begin().await?; sqlx::query( " - INSERT INTO room_participants (room_id, user_id, calling_user_id, initial_project_id) - VALUES ($1, $2, $3, $4) + INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id) + VALUES ($1, $2, $3, $4, $5) ", ) .bind(room_id) .bind(called_user_id) .bind(calling_user_id) + .bind(calling_connection_id.0 as i32) .bind(initial_project_id) .execute(&mut tx) .await?; @@ -961,7 +964,7 @@ where " SELECT room_id FROM room_participants - WHERE user_id = $1 AND connection_id IS NULL + WHERE user_id = $1 AND answering_connection_id IS NULL ", ) .bind(user_id) @@ -1033,7 +1036,7 @@ where sqlx::query( " DELETE FROM room_participants - WHERE room_id = $1 AND user_id = $2 AND connection_id IS NULL + WHERE room_id = $1 AND user_id = $2 AND answering_connection_id IS NULL ", ) .bind(room_id) @@ -1056,7 +1059,7 @@ where sqlx::query( " UPDATE room_participants - SET connection_id = $1 + SET answering_connection_id = $1 WHERE room_id = $2 AND user_id = $3 RETURNING 1 ", @@ -1070,101 +1073,100 @@ where }) } - pub async fn leave_room( - &self, - room_id: RoomId, - connection_id: ConnectionId, - ) -> Result { + pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { test_support!(self, { let mut tx = self.pool.begin().await?; // Leave room. - let user_id: UserId = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE room_id = $1 AND connection_id = $2 - RETURNING user_id - ", - ) - .bind(room_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Cancel pending calls initiated by the leaving user. 
- let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( + let room_id = sqlx::query_scalar::<_, RoomId>( " DELETE FROM room_participants - WHERE calling_user_id = $1 AND connection_id IS NULL - RETURNING user_id + WHERE answering_connection_id = $1 + RETURNING room_id ", ) - .bind(room_id) .bind(connection_id.0 as i32) - .fetch_all(&mut tx) + .fetch_optional(&mut tx) .await?; - let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>( - " - SELECT project_collaborators.* - FROM projects, project_collaborators - WHERE - projects.room_id = $1 AND - projects.host_user_id = $2 AND - projects.id = project_collaborators.project_id - ", - ) - .bind(room_id) - .bind(user_id) - .fetch(&mut tx); - - let mut left_projects = HashMap::default(); - while let Some(collaborator) = project_collaborators.next().await { - let collaborator = collaborator?; - let left_project = - left_projects - .entry(collaborator.project_id) - .or_insert(LeftProject { - id: collaborator.project_id, - host_user_id: Default::default(), - connection_ids: Default::default(), - }); + if let Some(room_id) = room_id { + // Cancel pending calls initiated by the leaving user. 
+ let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( + " + DELETE FROM room_participants + WHERE calling_connection_id = $1 AND answering_connection_id IS NULL + RETURNING user_id + ", + ) + .bind(connection_id.0 as i32) + .fetch_all(&mut tx) + .await?; - let collaborator_connection_id = ConnectionId(collaborator.connection_id as u32); - if collaborator_connection_id != connection_id || collaborator.is_host { - left_project.connection_ids.push(collaborator_connection_id); - } + let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>( + " + SELECT project_collaborators.* + FROM projects, project_collaborators + WHERE + projects.room_id = $1 AND + projects.host_connection_id = $2 AND + projects.id = project_collaborators.project_id + ", + ) + .bind(room_id) + .bind(connection_id.0 as i32) + .fetch(&mut tx); + + let mut left_projects = HashMap::default(); + while let Some(collaborator) = project_collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + }); + + let collaborator_connection_id = + ConnectionId(collaborator.connection_id as u32); + if collaborator_connection_id != connection_id || collaborator.is_host { + left_project.connection_ids.push(collaborator_connection_id); + } - if collaborator.is_host { - left_project.host_user_id = collaborator.user_id; + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + } } - } - drop(project_collaborators); + drop(project_collaborators); - sqlx::query( - " - DELETE FROM projects - WHERE room_id = $1 AND host_user_id = $2 - ", - ) - .bind(room_id) - .bind(user_id) - .execute(&mut tx) - .await?; + sqlx::query( + " + DELETE FROM projects + WHERE room_id = $1 AND host_connection_id = $2 + ", + ) + .bind(room_id) + .bind(connection_id.0 as i32) + .execute(&mut tx) + 
.await?; - let room = self.commit_room_transaction(room_id, tx).await?; - Ok(LeftRoom { - room, - left_projects, - canceled_calls_to_user_ids, - }) + let room = self.commit_room_transaction(room_id, tx).await?; + Ok(Some(LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + })) + } else { + Ok(None) + } }) } pub async fn update_room_participant_location( &self, room_id: RoomId, - user_id: UserId, + connection_id: ConnectionId, location: proto::ParticipantLocation, ) -> Result { test_support!(self, { @@ -1194,13 +1196,13 @@ where " UPDATE room_participants SET location_kind = $1 AND location_project_id = $2 - WHERE room_id = $1 AND user_id = $2 + WHERE room_id = $3 AND answering_connection_id = $4 ", ) .bind(location_kind) .bind(location_project_id) .bind(room_id) - .bind(user_id) + .bind(connection_id.0 as i32) .execute(&mut tx) .await?; @@ -1248,7 +1250,7 @@ where let mut db_participants = sqlx::query_as::<_, (UserId, Option, Option, Option, UserId, Option)>( " - SELECT user_id, connection_id, location_kind, location_project_id, calling_user_id, initial_project_id + SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id FROM room_participants WHERE room_id = $1 ", @@ -1261,16 +1263,16 @@ where while let Some(participant) = db_participants.next().await { let ( user_id, - connection_id, + answering_connection_id, _location_kind, _location_project_id, calling_user_id, initial_project_id, ) = participant?; - if let Some(connection_id) = connection_id { + if let Some(answering_connection_id) = answering_connection_id { participants.push(proto::Participant { user_id: user_id.to_proto(), - peer_id: connection_id as u32, + peer_id: answering_connection_id as u32, projects: Default::default(), location: Some(proto::ParticipantLocation { variant: Some(proto::participant_location::Variant::External( @@ -1339,12 +1341,13 @@ where let mut tx = self.pool.begin().await?; let project_id = sqlx::query_scalar( 
" - INSERT INTO projects (host_user_id, room_id) - VALUES ($1) + INSERT INTO projects (host_user_id, host_connection_id, room_id) + VALUES ($1, $2, $3) RETURNING id ", ) .bind(user_id) + .bind(connection_id.0 as i32) .bind(room_id) .fetch_one(&mut tx) .await @@ -1366,11 +1369,11 @@ where sqlx::query( " INSERT INTO project_collaborators ( - project_id, - connection_id, - user_id, - replica_id, - is_host + project_id, + connection_id, + user_id, + replica_id, + is_host ) VALUES ($1, $2, $3, $4, $5) ", diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 5b713226b1859e66473409c3ed4d43a6b84dcc24..e69393c642eda26480119a99a7506efa913a7c13 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -415,7 +415,7 @@ impl Server { drop(foreground_message_handlers); tracing::info!(%user_id, %login, %connection_id, %address, "signing out"); - if let Err(error) = this.sign_out(connection_id).await { + if let Err(error) = this.sign_out(connection_id, user_id).await { tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out"); } @@ -424,69 +424,15 @@ impl Server { } #[instrument(skip(self), err)] - async fn sign_out(self: &mut Arc, connection_id: ConnectionId) -> Result<()> { + async fn sign_out( + self: &mut Arc, + connection_id: ConnectionId, + user_id: UserId, + ) -> Result<()> { self.peer.disconnect(connection_id); - - let mut projects_to_unshare = Vec::new(); - let mut contacts_to_update = HashSet::default(); - let mut room_left = None; - { - let removed_connection = self.store().await.remove_connection(connection_id)?; - self.app_state.db.remove_connection(connection_id); - - for project in removed_connection.hosted_projects { - projects_to_unshare.push(project.id); - broadcast(connection_id, project.guests.keys().copied(), |conn_id| { - self.peer.send( - conn_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - }); - } - - for project in removed_connection.guest_projects { - 
broadcast(connection_id, project.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: connection_id.0, - }, - ) - }); - } - - if let Some(room) = removed_connection.room { - self.room_updated(&room); - room_left = Some(self.room_left(&room, connection_id)); - } - - contacts_to_update.insert(removed_connection.user_id); - for connection_id in removed_connection.canceled_call_connection_ids { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - contacts_to_update.extend(store.user_id_for_connection(connection_id).ok()); - } - }; - - if let Some(room_left) = room_left { - room_left.await.trace_err(); - } - - for user_id in contacts_to_update { - self.update_user_contacts(user_id).await.trace_err(); - } - - for project_id in projects_to_unshare { - self.app_state - .db - .unshare_project(project_id) - .await - .trace_err(); - } + self.store().await.remove_connection(connection_id)?; + self.leave_room_for_connection(connection_id, user_id) + .await?; Ok(()) } @@ -653,66 +599,90 @@ impl Server { } async fn leave_room(self: Arc, message: Message) -> Result<()> { + self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id) + .await + } + + async fn leave_room_for_connection( + self: &Arc, + connection_id: ConnectionId, + user_id: UserId, + ) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let left_room = self - .app_state - .db - .leave_room( - RoomId::from_proto(message.payload.id), - message.sender_connection_id, - ) - .await?; - contacts_to_update.insert(message.sender_user_id); + let Some(left_room) = self.app_state.db.leave_room(connection_id).await? 
else { + return Err(anyhow!("no room to leave"))?; + }; + contacts_to_update.insert(user_id); for project in left_room.left_projects.into_values() { - if project.host_user_id == message.sender_user_id { + if project.host_user_id == user_id { for connection_id in project.connection_ids { - self.peer.send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - )?; + self.peer + .send( + connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); } } else { for connection_id in project.connection_ids { - self.peer.send( + self.peer + .send( + connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: connection_id.0, + }, + ) + .trace_err(); + } + + self.peer + .send( connection_id, - proto::RemoveProjectCollaborator { + proto::UnshareProject { project_id: project.id.to_proto(), - peer_id: message.sender_connection_id.0, }, - )?; - } - - self.peer.send( - message.sender_connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - )?; + ) + .trace_err(); } } self.room_updated(&left_room.room); { let store = self.store().await; - for user_id in left_room.canceled_calls_to_user_ids { - for connection_id in store.connection_ids_for_user(user_id) { + for canceled_user_id in left_room.canceled_calls_to_user_ids { + for connection_id in store.connection_ids_for_user(canceled_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); } - contacts_to_update.insert(user_id); + contacts_to_update.insert(canceled_user_id); } } - self.room_left(&left_room.room, message.sender_connection_id) - .await - .trace_err(); - for user_id in contacts_to_update { - self.update_user_contacts(user_id).await?; + for contact_user_id in contacts_to_update { + self.update_user_contacts(contact_user_id).await?; + } + + if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { + live_kit + .remove_participant( + 
left_room.room.live_kit_room.clone(), + connection_id.to_string(), + ) + .await + .trace_err(); + + if left_room.room.participants.is_empty() { + live_kit + .delete_room(left_room.room.live_kit_room) + .await + .trace_err(); + } } Ok(()) @@ -725,6 +695,7 @@ impl Server { ) -> Result<()> { let room_id = RoomId::from_proto(request.payload.room_id); let calling_user_id = request.sender_user_id; + let calling_connection_id = request.sender_connection_id; let called_user_id = UserId::from_proto(request.payload.called_user_id); let initial_project_id = request .payload @@ -742,7 +713,13 @@ impl Server { let (room, incoming_call) = self .app_state .db - .call(room_id, calling_user_id, called_user_id, initial_project_id) + .call( + room_id, + calling_user_id, + calling_connection_id, + called_user_id, + initial_project_id, + ) .await?; self.room_updated(&room); self.update_user_contacts(called_user_id).await?; @@ -838,7 +815,7 @@ impl Server { let room = self .app_state .db - .update_room_participant_location(room_id, request.sender_user_id, location) + .update_room_participant_location(room_id, request.sender_connection_id, location) .await?; self.room_updated(&room); response.send(proto::Ack {})?; @@ -858,29 +835,6 @@ impl Server { } } - fn room_left( - &self, - room: &proto::Room, - connection_id: ConnectionId, - ) -> impl Future> { - let client = self.app_state.live_kit_client.clone(); - let room_name = room.live_kit_room.clone(); - let participant_count = room.participants.len(); - async move { - if let Some(client) = client { - client - .remove_participant(room_name.clone(), connection_id.to_string()) - .await?; - - if participant_count == 0 { - client.delete_room(room_name).await?; - } - } - - Ok(()) - } - } - async fn share_project( self: Arc, request: Message, diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index de444924091d3c6ce7013b872121f85a3fc03bb4..3896b8f7a40a9f7f2e1b09d032cf8b38dcd83cce 100644 --- 
a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::{borrow::Cow, mem, path::PathBuf, str}; +use std::{mem, path::PathBuf, str}; use tracing::instrument; pub type RoomId = u64; @@ -135,14 +135,6 @@ impl Store { Ok(()) } - pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result { - Ok(self - .connections - .get(&connection_id) - .ok_or_else(|| anyhow!("unknown connection"))? - .user_id) - } - pub fn connection_ids_for_user( &self, user_id: UserId, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index c1daf758230058b53c7929e8a5cf3859b4d9f81b..a93c0b593fad871a7cf1b22ad16918e439765e14 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -158,9 +158,7 @@ message JoinRoomResponse { optional LiveKitConnectionInfo live_kit_connection_info = 2; } -message LeaveRoom { - uint64 id = 1; -} +message LeaveRoom {} message Room { uint64 id = 1; From 0310e27347cc19e45198df270842cab2b668f34b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 10:53:11 +0100 Subject: [PATCH 013/109] Fix query errors in `Db::share_project` --- crates/collab/src/db.rs | 11 +++++++---- crates/collab/src/rpc.rs | 3 ++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index f32bdf96eff4725f93d2154f7f9c85336fb21340..d329bf23e500615c42b927035b12de997cb1c153 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1296,10 +1296,11 @@ where SELECT projects.id, worktrees.root_name FROM projects LEFT JOIN worktrees ON projects.id = worktrees.project_id - WHERE room_id = $1 AND host_user_id = $2 + WHERE room_id = $1 AND host_connection_id = $2 ", ) .bind(room_id) + .bind(participant.peer_id as i32) .fetch(&mut *tx); let mut projects = HashMap::default(); @@ 
-1341,14 +1342,14 @@ where let mut tx = self.pool.begin().await?; let project_id = sqlx::query_scalar( " - INSERT INTO projects (host_user_id, host_connection_id, room_id) + INSERT INTO projects (room_id, host_user_id, host_connection_id) VALUES ($1, $2, $3) RETURNING id ", ) + .bind(room_id) .bind(user_id) .bind(connection_id.0 as i32) - .bind(room_id) .fetch_one(&mut tx) .await .map(ProjectId)?; @@ -1356,7 +1357,8 @@ where for worktree in worktrees { sqlx::query( " - INSERT INTO worktrees (id, project_id, root_name) + INSERT INTO worktrees (id, project_id, root_name) + VALUES ($1, $2, $3) ", ) .bind(worktree.id as i32) @@ -1387,6 +1389,7 @@ where .await?; let room = self.commit_room_transaction(room_id, tx).await?; + dbg!(&room); Ok((project_id, room)) }) } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index e69393c642eda26480119a99a7506efa913a7c13..038724c25a23eb353891c7e5fdb4c5aa9237361e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -849,7 +849,8 @@ impl Server { RoomId::from_proto(request.payload.room_id), &request.payload.worktrees, ) - .await?; + .await + .unwrap(); response.send(proto::ShareProjectResponse { project_id: project_id.to_proto(), })?; From 59e8600e4c43e412f6088eb80dfe4a78f5fb3969 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 11:12:23 +0100 Subject: [PATCH 014/109] Implement `Db::cancel_call` --- crates/collab/src/db.rs | 29 ++++++++++++++++++- crates/collab/src/rpc.rs | 44 ++++++++++++++--------------- crates/collab/src/rpc/store.rs | 51 ---------------------------------- 3 files changed, 50 insertions(+), 74 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index d329bf23e500615c42b927035b12de997cb1c153..50a333bced9aec4c5d4a761b49f4cdc084e9a711 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1048,6 +1048,30 @@ where }) } + pub async fn cancel_call( + &self, + room_id: RoomId, + calling_connection_id: ConnectionId, + 
called_user_id: UserId, + ) -> Result { + test_support!(self, { + let mut tx = self.pool.begin().await?; + sqlx::query( + " + DELETE FROM room_participants + WHERE room_id = $1 AND user_id = $2 AND calling_connection_id = $3 AND answering_connection_id IS NULL + ", + ) + .bind(room_id) + .bind(called_user_id) + .bind(calling_connection_id.0 as i32) + .execute(&mut tx) + .await?; + + self.commit_room_transaction(room_id, tx).await + }) + } + pub async fn join_room( &self, room_id: RoomId, @@ -1073,7 +1097,10 @@ where }) } - pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { + pub async fn leave_room_for_connection( + &self, + connection_id: ConnectionId, + ) -> Result> { test_support!(self, { let mut tx = self.pool.begin().await?; diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 038724c25a23eb353891c7e5fdb4c5aa9237361e..3e519d91aefae54ba16091c6e8bcd2f3230be9d5 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -558,13 +558,13 @@ impl Server { request.sender_connection_id, ) .await?; - for recipient_id in self + for connection_id in self .store() .await .connection_ids_for_user(request.sender_user_id) { self.peer - .send(recipient_id, proto::CallCanceled {}) + .send(connection_id, proto::CallCanceled {}) .trace_err(); } @@ -610,7 +610,7 @@ impl Server { ) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let Some(left_room) = self.app_state.db.leave_room(connection_id).await? else { + let Some(left_room) = self.app_state.db.leave_room_for_connection(connection_id).await? else { return Err(anyhow!("no room to leave"))?; }; contacts_to_update.insert(user_id); @@ -751,7 +751,7 @@ impl Server { self.room_updated(&room); self.update_user_contacts(called_user_id).await?; - Err(anyhow!("failed to ring call recipient"))? + Err(anyhow!("failed to ring user"))? 
} async fn cancel_call( @@ -759,23 +759,23 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let recipient_user_id = UserId::from_proto(request.payload.called_user_id); - { - let mut store = self.store().await; - let (room, recipient_connection_ids) = store.cancel_call( - request.payload.room_id, - recipient_user_id, - request.sender_connection_id, - )?; - for recipient_id in recipient_connection_ids { - self.peer - .send(recipient_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(room); - response.send(proto::Ack {})?; + let called_user_id = UserId::from_proto(request.payload.called_user_id); + let room_id = RoomId::from_proto(request.payload.room_id); + + let room = self + .app_state + .db + .cancel_call(room_id, request.sender_connection_id, called_user_id) + .await?; + for connection_id in self.store().await.connection_ids_for_user(called_user_id) { + self.peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } - self.update_user_contacts(recipient_user_id).await?; + self.room_updated(&room); + response.send(proto::Ack {})?; + + self.update_user_contacts(called_user_id).await?; Ok(()) } @@ -788,13 +788,13 @@ impl Server { message.sender_user_id, ) .await?; - for recipient_id in self + for connection_id in self .store() .await .connection_ids_for_user(message.sender_user_id) { self.peer - .send(recipient_id, proto::CallCanceled {}) + .send(connection_id, proto::CallCanceled {}) .trace_err(); } self.room_updated(&room); diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 3896b8f7a40a9f7f2e1b09d032cf8b38dcd83cce..a9793e9fb67af8e97c11d79d59a3b7927d24d3cd 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -211,57 +211,6 @@ impl Store { &self.rooms } - pub fn cancel_call( - &mut self, - room_id: RoomId, - called_user_id: UserId, - canceller_connection_id: ConnectionId, - ) -> Result<(&proto::Room, HashSet)> { - todo!() - // let canceller_user_id = 
self.user_id_for_connection(canceller_connection_id)?; - // let canceller = self - // .connected_users - // .get(&canceller_user_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // let recipient = self - // .connected_users - // .get(&called_user_id) - // .ok_or_else(|| anyhow!("no such connection"))?; - // let canceller_active_call = canceller - // .active_call - // .as_ref() - // .ok_or_else(|| anyhow!("no active call"))?; - // let recipient_active_call = recipient - // .active_call - // .as_ref() - // .ok_or_else(|| anyhow!("no active call for recipient"))?; - - // anyhow::ensure!( - // canceller_active_call.room_id == room_id, - // "users are on different calls" - // ); - // anyhow::ensure!( - // recipient_active_call.room_id == room_id, - // "users are on different calls" - // ); - // anyhow::ensure!( - // recipient_active_call.connection_id.is_none(), - // "recipient has already answered" - // ); - // let room_id = recipient_active_call.room_id; - // let room = self - // .rooms - // .get_mut(&room_id) - // .ok_or_else(|| anyhow!("no such room"))?; - // room.pending_participant_user_ids - // .retain(|user_id| UserId::from_proto(*user_id) != called_user_id); - - // let recipient = self.connected_users.get_mut(&called_user_id).unwrap(); - // recipient.active_call.take(); - - // Ok((room, recipient.connection_ids.clone())) - } - pub fn unshare_project( &mut self, project_id: ProjectId, From 65c5adff058c757142bcd8041806015b08d114a3 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 11:32:26 +0100 Subject: [PATCH 015/109] Automatically decline call when user drops their last connection --- crates/collab/src/db.rs | 30 ++++++++++++++++++++---------- crates/collab/src/rpc.rs | 33 +++++++++++++++++++++++++-------- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 50a333bced9aec4c5d4a761b49f4cdc084e9a711..39bc2775a0fe9f6ecf615fd08c4625d9d30a4572 100644 --- 
a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1030,19 +1030,26 @@ where }) } - pub async fn decline_call(&self, room_id: RoomId, user_id: UserId) -> Result { + pub async fn decline_call( + &self, + expected_room_id: Option, + user_id: UserId, + ) -> Result { test_support!(self, { let mut tx = self.pool.begin().await?; - sqlx::query( + let room_id = sqlx::query_scalar( " DELETE FROM room_participants - WHERE room_id = $1 AND user_id = $2 AND answering_connection_id IS NULL + WHERE user_id = $1 AND answering_connection_id IS NULL + RETURNING room_id ", ) - .bind(room_id) .bind(user_id) - .execute(&mut tx) + .fetch_one(&mut tx) .await?; + if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + return Err(anyhow!("declining call on unexpected room"))?; + } self.commit_room_transaction(room_id, tx).await }) @@ -1050,23 +1057,26 @@ where pub async fn cancel_call( &self, - room_id: RoomId, + expected_room_id: Option, calling_connection_id: ConnectionId, called_user_id: UserId, ) -> Result { test_support!(self, { let mut tx = self.pool.begin().await?; - sqlx::query( + let room_id = sqlx::query_scalar( " DELETE FROM room_participants - WHERE room_id = $1 AND user_id = $2 AND calling_connection_id = $3 AND answering_connection_id IS NULL + WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL + RETURNING room_id ", ) - .bind(room_id) .bind(called_user_id) .bind(calling_connection_id.0 as i32) - .execute(&mut tx) + .fetch_one(&mut tx) .await?; + if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + return Err(anyhow!("canceling call on unexpected room"))?; + } self.commit_room_transaction(room_id, tx).await }) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3e519d91aefae54ba16091c6e8bcd2f3230be9d5..d9c8c616f38ba6228c84949b45721a7021df0cc3 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -430,9 +430,29 @@ impl Server { 
user_id: UserId, ) -> Result<()> { self.peer.disconnect(connection_id); - self.store().await.remove_connection(connection_id)?; + let decline_calls = { + let mut store = self.store().await; + store.remove_connection(connection_id)?; + let mut connections = store.connection_ids_for_user(user_id); + connections.next().is_none() + }; + self.leave_room_for_connection(connection_id, user_id) - .await?; + .await + .trace_err(); + if decline_calls { + if let Some(room) = self + .app_state + .db + .decline_call(None, user_id) + .await + .trace_err() + { + self.room_updated(&room); + } + } + + self.update_user_contacts(user_id).await?; Ok(()) } @@ -761,11 +781,10 @@ impl Server { ) -> Result<()> { let called_user_id = UserId::from_proto(request.payload.called_user_id); let room_id = RoomId::from_proto(request.payload.room_id); - let room = self .app_state .db - .cancel_call(room_id, request.sender_connection_id, called_user_id) + .cancel_call(Some(room_id), request.sender_connection_id, called_user_id) .await?; for connection_id in self.store().await.connection_ids_for_user(called_user_id) { self.peer @@ -780,13 +799,11 @@ impl Server { } async fn decline_call(self: Arc, message: Message) -> Result<()> { + let room_id = RoomId::from_proto(message.payload.room_id); let room = self .app_state .db - .decline_call( - RoomId::from_proto(message.payload.room_id), - message.sender_user_id, - ) + .decline_call(Some(room_id), message.sender_user_id) .await?; for connection_id in self .store() From 40073f6100acff471c903c011695117f9751a3d1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 15:32:49 +0100 Subject: [PATCH 016/109] Wait for acknowledgment before sending the next project update --- crates/call/src/room.rs | 2 + crates/collab/src/db.rs | 16 +- crates/collab/src/integration_tests.rs | 12 +- crates/collab/src/rpc.rs | 6 +- crates/collab_ui/src/collab_ui.rs | 1 - crates/project/src/project.rs | 218 ++++++++----------------- crates/project/src/worktree.rs 
| 19 --- crates/rpc/proto/zed.proto | 7 +- crates/rpc/src/proto.rs | 9 +- crates/workspace/src/workspace.rs | 8 +- crates/zed/src/main.rs | 4 +- 11 files changed, 94 insertions(+), 208 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 0ecd6082d63f576be7c4f3342aa679921600e873..c1b0dc191d07bb4cbb6e83ab3239260ca0e0edb1 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -287,6 +287,8 @@ impl Room { mut room: proto::Room, cx: &mut ModelContext, ) -> Result<()> { + // TODO: honor room version. + // Filter ourselves out from the room's participants. let local_participant_ix = room .participants diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 39bc2775a0fe9f6ecf615fd08c4625d9d30a4572..a12985b94bef8d71dd25db32845d82866ddd33e0 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1145,8 +1145,8 @@ where FROM projects, project_collaborators WHERE projects.room_id = $1 AND - projects.host_connection_id = $2 AND - projects.id = project_collaborators.project_id + projects.id = project_collaborators.project_id AND + project_collaborators.connection_id = $2 ", ) .bind(room_id) @@ -1370,9 +1370,9 @@ where pub async fn share_project( &self, + room_id: RoomId, user_id: UserId, connection_id: ConnectionId, - room_id: RoomId, worktrees: &[proto::WorktreeMetadata], ) -> Result<(ProjectId, proto::Room)> { test_support!(self, { @@ -1426,11 +1426,19 @@ where .await?; let room = self.commit_room_transaction(room_id, tx).await?; - dbg!(&room); Ok((project_id, room)) }) } + // pub async fn join_project( + // &self, + // user_id: UserId, + // connection_id: ConnectionId, + // project_id: ProjectId, + // ) -> Result<(Project, ReplicaId)> { + // todo!() + // } + pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { todo!() // test_support!(self, { diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 
3a4c2368e8060994482b464dfbe7081a080e3efc..b54f03ce53e0aa6200814f8db9f1fc67744b718a 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -30,9 +30,7 @@ use language::{ use live_kit_client::MacOSDisplay; use lsp::{self, FakeLanguageServer}; use parking_lot::Mutex; -use project::{ - search::SearchQuery, DiagnosticSummary, Project, ProjectPath, ProjectStore, WorktreeId, -}; +use project::{search::SearchQuery, DiagnosticSummary, Project, ProjectPath, WorktreeId}; use rand::prelude::*; use serde_json::json; use settings::{Formatter, Settings}; @@ -2280,7 +2278,6 @@ async fn test_leaving_project( project_id, client_b.client.clone(), client_b.user_store.clone(), - client_b.project_store.clone(), client_b.language_registry.clone(), FakeFs::new(cx.background()), cx, @@ -5792,11 +5789,9 @@ impl TestServer { let fs = FakeFs::new(cx.background()); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); - let project_store = cx.add_model(|_| ProjectStore::new()); let app_state = Arc::new(workspace::AppState { client: client.clone(), user_store: user_store.clone(), - project_store: project_store.clone(), languages: Arc::new(LanguageRegistry::new(Task::ready(()))), themes: ThemeRegistry::new((), cx.font_cache()), fs: fs.clone(), @@ -5823,7 +5818,6 @@ impl TestServer { remote_projects: Default::default(), next_root_dir_id: 0, user_store, - project_store, fs, language_registry: Arc::new(LanguageRegistry::test()), buffers: Default::default(), @@ -5929,7 +5923,6 @@ struct TestClient { remote_projects: Vec>, next_root_dir_id: usize, pub user_store: ModelHandle, - pub project_store: ModelHandle, language_registry: Arc, fs: Arc, buffers: HashMap, HashSet>>, @@ -5999,7 +5992,6 @@ impl TestClient { Project::local( self.client.clone(), self.user_store.clone(), - self.project_store.clone(), self.language_registry.clone(), self.fs.clone(), cx, @@ -6027,7 +6019,6 @@ impl TestClient { host_project_id, self.client.clone(), 
self.user_store.clone(), - self.project_store.clone(), self.language_registry.clone(), FakeFs::new(cx.background()), cx, @@ -6157,7 +6148,6 @@ impl TestClient { remote_project_id, client.client.clone(), client.user_store.clone(), - client.project_store.clone(), client.language_registry.clone(), FakeFs::new(cx.background()), cx.to_async(), diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index d9c8c616f38ba6228c84949b45721a7021df0cc3..bed6ebf9cd649a10bb9bce2d931606e2a56ed281 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -151,7 +151,7 @@ impl Server { .add_message_handler(Server::unshare_project) .add_request_handler(Server::join_project) .add_message_handler(Server::leave_project) - .add_message_handler(Server::update_project) + .add_request_handler(Server::update_project) .add_request_handler(Server::update_worktree) .add_message_handler(Server::start_language_server) .add_message_handler(Server::update_language_server) @@ -861,9 +861,9 @@ impl Server { .app_state .db .share_project( + RoomId::from_proto(request.payload.room_id), request.sender_user_id, request.sender_connection_id, - RoomId::from_proto(request.payload.room_id), &request.payload.worktrees, ) .await @@ -1084,6 +1084,7 @@ impl Server { async fn update_project( self: Arc, request: Message, + response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); { @@ -1108,6 +1109,7 @@ impl Server { }, ); self.room_updated(room); + response.send(proto::Ack {})?; }; Ok(()) diff --git a/crates/collab_ui/src/collab_ui.rs b/crates/collab_ui/src/collab_ui.rs index f5f508ce5b167059cf8c3fbaebcb0e1d5e80b996..dc8a1716989e4b5ccec4ae476591b3f8f41339f3 100644 --- a/crates/collab_ui/src/collab_ui.rs +++ b/crates/collab_ui/src/collab_ui.rs @@ -43,7 +43,6 @@ pub fn init(app_state: Arc, cx: &mut MutableAppContext) { project_id, app_state.client.clone(), app_state.user_store.clone(), - app_state.project_store.clone(), 
app_state.languages.clone(), app_state.fs.clone(), cx.clone(), diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 3c28f6b512e38f9f803398641b7b9676beec234f..d01571f44b1f7df78698420af0f19e282a4d8c55 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -70,10 +70,6 @@ pub trait Item: Entity { fn entry_id(&self, cx: &AppContext) -> Option; } -pub struct ProjectStore { - projects: Vec>, -} - // Language server state is stored across 3 collections: // language_servers => // a mapping from unique server id to LanguageServerState which can either be a task for a @@ -102,7 +98,6 @@ pub struct Project { next_entry_id: Arc, next_diagnostic_group_id: usize, user_store: ModelHandle, - project_store: ModelHandle, fs: Arc, client_state: Option, collaborators: HashMap, @@ -152,6 +147,8 @@ enum WorktreeHandle { enum ProjectClientState { Local { remote_id: u64, + metadata_changed: watch::Sender<()>, + _maintain_metadata: Task<()>, _detect_unshare: Task>, }, Remote { @@ -376,7 +373,7 @@ impl Project { client.add_model_message_handler(Self::handle_start_language_server); client.add_model_message_handler(Self::handle_update_language_server); client.add_model_message_handler(Self::handle_remove_collaborator); - client.add_model_message_handler(Self::handle_update_project); + client.add_model_message_handler(Self::handle_project_updated); client.add_model_message_handler(Self::handle_unshare_project); client.add_model_message_handler(Self::handle_create_buffer_for_peer); client.add_model_message_handler(Self::handle_update_buffer_file); @@ -412,46 +409,39 @@ impl Project { pub fn local( client: Arc, user_store: ModelHandle, - project_store: ModelHandle, languages: Arc, fs: Arc, cx: &mut MutableAppContext, ) -> ModelHandle { - cx.add_model(|cx: &mut ModelContext| { - let handle = cx.weak_handle(); - project_store.update(cx, |store, cx| store.add_project(handle, cx)); - - Self { - worktrees: Default::default(), - collaborators: 
Default::default(), - opened_buffers: Default::default(), - shared_buffers: Default::default(), - incomplete_buffers: Default::default(), - loading_buffers: Default::default(), - loading_local_worktrees: Default::default(), - buffer_snapshots: Default::default(), - client_state: None, - opened_buffer: watch::channel(), - client_subscriptions: Vec::new(), - _subscriptions: vec![cx.observe_global::(Self::on_settings_changed)], - _maintain_buffer_languages: Self::maintain_buffer_languages(&languages, cx), - active_entry: None, - languages, - client, - user_store, - project_store, - fs, - next_entry_id: Default::default(), - next_diagnostic_group_id: Default::default(), - language_servers: Default::default(), - language_server_ids: Default::default(), - language_server_statuses: Default::default(), - last_workspace_edits_by_language_server: Default::default(), - language_server_settings: Default::default(), - buffers_being_formatted: Default::default(), - next_language_server_id: 0, - nonce: StdRng::from_entropy().gen(), - } + cx.add_model(|cx: &mut ModelContext| Self { + worktrees: Default::default(), + collaborators: Default::default(), + opened_buffers: Default::default(), + shared_buffers: Default::default(), + incomplete_buffers: Default::default(), + loading_buffers: Default::default(), + loading_local_worktrees: Default::default(), + buffer_snapshots: Default::default(), + client_state: None, + opened_buffer: watch::channel(), + client_subscriptions: Vec::new(), + _subscriptions: vec![cx.observe_global::(Self::on_settings_changed)], + _maintain_buffer_languages: Self::maintain_buffer_languages(&languages, cx), + active_entry: None, + languages, + client, + user_store, + fs, + next_entry_id: Default::default(), + next_diagnostic_group_id: Default::default(), + language_servers: Default::default(), + language_server_ids: Default::default(), + language_server_statuses: Default::default(), + last_workspace_edits_by_language_server: Default::default(), + 
language_server_settings: Default::default(), + buffers_being_formatted: Default::default(), + next_language_server_id: 0, + nonce: StdRng::from_entropy().gen(), }) } @@ -459,7 +449,6 @@ impl Project { remote_id: u64, client: Arc, user_store: ModelHandle, - project_store: ModelHandle, languages: Arc, fs: Arc, mut cx: AsyncAppContext, @@ -482,9 +471,6 @@ impl Project { } let this = cx.add_model(|cx: &mut ModelContext| { - let handle = cx.weak_handle(); - project_store.update(cx, |store, cx| store.add_project(handle, cx)); - let mut this = Self { worktrees: Vec::new(), loading_buffers: Default::default(), @@ -497,7 +483,6 @@ impl Project { _maintain_buffer_languages: Self::maintain_buffer_languages(&languages, cx), languages, user_store: user_store.clone(), - project_store, fs, next_entry_id: Default::default(), next_diagnostic_group_id: Default::default(), @@ -593,9 +578,7 @@ impl Project { let http_client = client::test::FakeHttpClient::with_404_response(); let client = cx.update(|cx| client::Client::new(http_client.clone(), cx)); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); - let project_store = cx.add_model(|_| ProjectStore::new()); - let project = - cx.update(|cx| Project::local(client, user_store, project_store, languages, fs, cx)); + let project = cx.update(|cx| Project::local(client, user_store, languages, fs, cx)); for path in root_paths { let (tree, _) = project .update(cx, |project, cx| { @@ -676,10 +659,6 @@ impl Project { self.user_store.clone() } - pub fn project_store(&self) -> ModelHandle { - self.project_store.clone() - } - #[cfg(any(test, feature = "test-support"))] pub fn check_invariants(&self, cx: &AppContext) { if self.is_local() { @@ -752,51 +731,12 @@ impl Project { } fn metadata_changed(&mut self, cx: &mut ModelContext) { - if let Some(ProjectClientState::Local { remote_id, .. }) = &self.client_state { - let project_id = *remote_id; - // Broadcast worktrees only if the project is online. 
- let worktrees = self - .worktrees - .iter() - .filter_map(|worktree| { - worktree - .upgrade(cx) - .map(|worktree| worktree.read(cx).as_local().unwrap().metadata_proto()) - }) - .collect(); - self.client - .send(proto::UpdateProject { - project_id, - worktrees, - }) - .log_err(); - - let worktrees = self.visible_worktrees(cx).collect::>(); - let scans_complete = futures::future::join_all( - worktrees - .iter() - .filter_map(|worktree| Some(worktree.read(cx).as_local()?.scan_complete())), - ); - - let worktrees = worktrees.into_iter().map(|handle| handle.downgrade()); - - cx.spawn_weak(move |_, cx| async move { - scans_complete.await; - cx.read(|cx| { - for worktree in worktrees { - if let Some(worktree) = worktree - .upgrade(cx) - .and_then(|worktree| worktree.read(cx).as_local()) - { - worktree.send_extension_counts(project_id); - } - } - }) - }) - .detach(); + if let Some(ProjectClientState::Local { + metadata_changed, .. + }) = &mut self.client_state + { + *metadata_changed.borrow_mut() = (); } - - self.project_store.update(cx, |_, cx| cx.notify()); cx.notify(); } @@ -1092,8 +1032,32 @@ impl Project { cx.notify(); let mut status = self.client.status(); + let (metadata_changed_tx, mut metadata_changed_rx) = watch::channel(); self.client_state = Some(ProjectClientState::Local { remote_id: project_id, + metadata_changed: metadata_changed_tx, + _maintain_metadata: cx.spawn_weak(move |this, cx| async move { + while let Some(()) = metadata_changed_rx.next().await { + let Some(this) = this.upgrade(&cx) else { break }; + this.read_with(&cx, |this, cx| { + let worktrees = this + .worktrees + .iter() + .filter_map(|worktree| { + worktree.upgrade(cx).map(|worktree| { + worktree.read(cx).as_local().unwrap().metadata_proto() + }) + }) + .collect(); + this.client.request(proto::UpdateProject { + project_id, + worktrees, + }) + }) + .await + .log_err(); + } + }), _detect_unshare: cx.spawn_weak(move |this, mut cx| { async move { let is_connected = 
status.next().await.map_or(false, |s| s.is_connected()); @@ -1632,10 +1596,6 @@ impl Project { operations: vec![language::proto::serialize_operation(operation)], }); cx.background().spawn(request).detach_and_log_err(cx); - } else if let Some(project_id) = self.remote_id() { - let _ = self - .client - .send(proto::RegisterProjectActivity { project_id }); } } BufferEvent::Edited { .. } => { @@ -4573,9 +4533,9 @@ impl Project { }) } - async fn handle_update_project( + async fn handle_project_updated( this: ModelHandle, - envelope: TypedEnvelope, + envelope: TypedEnvelope, client: Arc, mut cx: AsyncAppContext, ) -> Result<()> { @@ -5832,48 +5792,6 @@ impl Project { } } -impl ProjectStore { - pub fn new() -> Self { - Self { - projects: Default::default(), - } - } - - pub fn projects<'a>( - &'a self, - cx: &'a AppContext, - ) -> impl 'a + Iterator> { - self.projects - .iter() - .filter_map(|project| project.upgrade(cx)) - } - - fn add_project(&mut self, project: WeakModelHandle, cx: &mut ModelContext) { - if let Err(ix) = self - .projects - .binary_search_by_key(&project.id(), WeakModelHandle::id) - { - self.projects.insert(ix, project); - } - cx.notify(); - } - - fn prune_projects(&mut self, cx: &mut ModelContext) { - let mut did_change = false; - self.projects.retain(|project| { - if project.is_upgradable(cx) { - true - } else { - did_change = true; - false - } - }); - if did_change { - cx.notify(); - } - } -} - impl WorktreeHandle { pub fn upgrade(&self, cx: &AppContext) -> Option> { match self { @@ -5952,16 +5870,10 @@ impl<'a> Iterator for PathMatchCandidateSetIter<'a> { } } -impl Entity for ProjectStore { - type Event = (); -} - impl Entity for Project { type Event = Event; - fn release(&mut self, cx: &mut gpui::MutableAppContext) { - self.project_store.update(cx, ProjectStore::prune_projects); - + fn release(&mut self, _: &mut gpui::MutableAppContext) { match &self.client_state { Some(ProjectClientState::Local { remote_id, .. 
}) => { self.client diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index db8fb8e3ff8e4cea90a33c805311d4032302a890..9e4ec3ffb9a236e8b9b13c871269833e225fa1b3 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -1051,25 +1051,6 @@ impl LocalWorktree { pub fn is_shared(&self) -> bool { self.share.is_some() } - - pub fn send_extension_counts(&self, project_id: u64) { - let mut extensions = Vec::new(); - let mut counts = Vec::new(); - - for (extension, count) in self.extension_counts() { - extensions.push(extension.to_string_lossy().to_string()); - counts.push(*count as u32); - } - - self.client - .send(proto::UpdateWorktreeExtensions { - project_id, - worktree_id: self.id().to_proto(), - extensions, - counts, - }) - .log_err(); - } } impl RemoteWorktree { diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index a93c0b593fad871a7cf1b22ad16918e439765e14..94880ce9f56e80e5b677eafa65878295a5b424e7 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -48,9 +48,8 @@ message Envelope { OpenBufferForSymbolResponse open_buffer_for_symbol_response = 40; UpdateProject update_project = 41; - RegisterProjectActivity register_project_activity = 42; + ProjectUpdated project_updated = 42; UpdateWorktree update_worktree = 43; - UpdateWorktreeExtensions update_worktree_extensions = 44; CreateProjectEntry create_project_entry = 45; RenameProjectEntry rename_project_entry = 46; @@ -258,8 +257,10 @@ message UpdateProject { repeated WorktreeMetadata worktrees = 2; } -message RegisterProjectActivity { +message ProjectUpdated { uint64 project_id = 1; + repeated WorktreeMetadata worktrees = 2; + uint64 room_version = 3; } message JoinProject { diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 11bbaaf5ffcbdeff96906033b1cacae6f62e48f0..31f53564a8b9d99c7bba00de3de969f95cfc1498 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -140,12 +140,12 @@ messages!( 
(OpenBufferResponse, Background), (PerformRename, Background), (PerformRenameResponse, Background), + (Ping, Foreground), (PrepareRename, Background), (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), + (ProjectUpdated, Foreground), (RemoveContact, Foreground), - (Ping, Foreground), - (RegisterProjectActivity, Foreground), (ReloadBuffers, Foreground), (ReloadBuffersResponse, Foreground), (RemoveProjectCollaborator, Foreground), @@ -175,7 +175,6 @@ messages!( (UpdateParticipantLocation, Foreground), (UpdateProject, Foreground), (UpdateWorktree, Foreground), - (UpdateWorktreeExtensions, Background), (UpdateDiffBase, Background), (GetPrivateUserInfo, Foreground), (GetPrivateUserInfoResponse, Foreground), @@ -231,6 +230,7 @@ request_messages!( (Test, Test), (UpdateBuffer, Ack), (UpdateParticipantLocation, Ack), + (UpdateProject, Ack), (UpdateWorktree, Ack), ); @@ -261,8 +261,8 @@ entity_messages!( OpenBufferByPath, OpenBufferForSymbol, PerformRename, + ProjectUpdated, PrepareRename, - RegisterProjectActivity, ReloadBuffers, RemoveProjectCollaborator, RenameProjectEntry, @@ -278,7 +278,6 @@ entity_messages!( UpdateLanguageServer, UpdateProject, UpdateWorktree, - UpdateWorktreeExtensions, UpdateDiffBase ); diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 2dbf923484e64fca19cdc885d368737377a543d3..9db524ee9ba2b935d817ce64081d8ee374bb363a 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -33,7 +33,7 @@ use log::{error, warn}; pub use pane::*; pub use pane_group::*; use postage::prelude::Stream; -use project::{Project, ProjectEntryId, ProjectPath, ProjectStore, Worktree, WorktreeId}; +use project::{Project, ProjectEntryId, ProjectPath, Worktree, WorktreeId}; use searchable::SearchableItemHandle; use serde::Deserialize; use settings::{Autosave, DockAnchor, Settings}; @@ -337,7 +337,6 @@ pub struct AppState { pub themes: Arc, pub client: Arc, pub user_store: ModelHandle, - 
pub project_store: ModelHandle, pub fs: Arc, pub build_window_options: fn() -> WindowOptions<'static>, pub initialize_workspace: fn(&mut Workspace, &Arc, &mut ViewContext), @@ -1039,7 +1038,6 @@ impl AppState { let languages = Arc::new(LanguageRegistry::test()); let http_client = client::test::FakeHttpClient::with_404_response(); let client = Client::new(http_client.clone(), cx); - let project_store = cx.add_model(|_| ProjectStore::new()); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); let themes = ThemeRegistry::new((), cx.font_cache().clone()); Arc::new(Self { @@ -1048,7 +1046,6 @@ impl AppState { fs, languages, user_store, - project_store, initialize_workspace: |_, _, _| {}, build_window_options: Default::default, default_item_factory: |_, _| unimplemented!(), @@ -1301,7 +1298,6 @@ impl Workspace { Project::local( app_state.client.clone(), app_state.user_store.clone(), - app_state.project_store.clone(), app_state.languages.clone(), app_state.fs.clone(), cx, @@ -2965,7 +2961,6 @@ pub fn open_paths( let project = Project::local( app_state.client.clone(), app_state.user_store.clone(), - app_state.project_store.clone(), app_state.languages.clone(), app_state.fs.clone(), cx, @@ -2997,7 +2992,6 @@ fn open_new(app_state: &Arc, cx: &mut MutableAppContext) { Project::local( app_state.client.clone(), app_state.user_store.clone(), - app_state.project_store.clone(), app_state.languages.clone(), app_state.fs.clone(), cx, diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index e849632a2df38945fcf34bf8b5967491f19df9e9..5a7ee2dbaee735ebf1132242aba7ef7fa674424f 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -23,7 +23,7 @@ use isahc::{config::Configurable, Request}; use language::LanguageRegistry; use log::LevelFilter; use parking_lot::Mutex; -use project::{Fs, HomeDir, ProjectStore}; +use project::{Fs, HomeDir}; use serde_json::json; use settings::{ self, settings_file::SettingsFile, KeymapFileContent, 
Settings, SettingsFileContent, @@ -146,7 +146,6 @@ fn main() { }) .detach(); - let project_store = cx.add_model(|_| ProjectStore::new()); let db = cx.background().block(db); client.start_telemetry(db.clone()); client.report_event("start app", Default::default()); @@ -156,7 +155,6 @@ fn main() { themes, client: client.clone(), user_store, - project_store, fs, build_window_options, initialize_workspace, From d7369ace6a2e911464c9d2099258203823934586 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 15:35:39 +0100 Subject: [PATCH 017/109] Skip applying room updates if they're older than the local room state --- crates/call/src/room.rs | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index c1b0dc191d07bb4cbb6e83ab3239260ca0e0edb1..4ba8d8effc4831599bb0e358a37fe535b3220f16 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -34,6 +34,7 @@ pub enum Event { pub struct Room { id: u64, + version: u64, live_kit: Option, status: RoomStatus, local_participant: LocalParticipant, @@ -61,6 +62,7 @@ impl Entity for Room { impl Room { fn new( id: u64, + version: u64, live_kit_connection_info: Option, client: Arc, user_store: ModelHandle, @@ -133,6 +135,7 @@ impl Room { Self { id, + version, live_kit: live_kit_room, status: RoomStatus::Online, participant_user_ids: Default::default(), @@ -161,6 +164,7 @@ impl Room { let room = cx.add_model(|cx| { Self::new( room_proto.id, + room_proto.version, response.live_kit_connection_info, client, user_store, @@ -205,6 +209,7 @@ impl Room { let room = cx.add_model(|cx| { Self::new( room_id, + 0, response.live_kit_connection_info, client, user_store, @@ -287,8 +292,6 @@ impl Room { mut room: proto::Room, cx: &mut ModelContext, ) -> Result<()> { - // TODO: honor room version. - // Filter ourselves out from the room's participants. 
let local_participant_ix = room .participants @@ -318,6 +321,10 @@ impl Room { futures::join!(remote_participants, pending_participants); this.update(&mut cx, |this, cx| { + if this.version >= room.version { + return; + } + this.participant_user_ids.clear(); if let Some(participant) = local_participant { @@ -422,6 +429,7 @@ impl Room { let _ = this.leave(cx); } + this.version = room.version; this.check_invariants(); cx.notify(); }); From b9af2ae66e31b6caa81de664ceb8d37e552d4599 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 14 Nov 2022 17:16:50 +0100 Subject: [PATCH 018/109] Switch to serializable isolation Co-Authored-By: Nathan Sobo --- .../20221109000000_test_schema.sql | 1 + crates/collab/src/db.rs | 404 ++++++++++++------ crates/collab/src/lib.rs | 8 +- crates/collab/src/rpc.rs | 45 +- 4 files changed, 298 insertions(+), 160 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 2cef514e5a1810026d62047cef6b61c817e33155..d262d6a8bd414a40cc71cc56690b3232e8eaaa81 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -56,6 +56,7 @@ CREATE TABLE "project_collaborators" ( "is_host" BOOLEAN NOT NULL ); CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id", "replica_id"); CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index a12985b94bef8d71dd25db32845d82866ddd33e0..b561ba045d1344bdce46f888df0614639a82dd8d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2,7 +2,7 @@ use crate::{Error, Result}; use anyhow::anyhow; use axum::http::StatusCode; use collections::HashMap; -use futures::StreamExt; +use futures::{future::BoxFuture, 
FutureExt, StreamExt}; use rpc::{proto, ConnectionId}; use serde::{Deserialize, Serialize}; use sqlx::{ @@ -10,7 +10,7 @@ use sqlx::{ types::Uuid, FromRow, }; -use std::{path::Path, time::Duration}; +use std::{future::Future, path::Path, time::Duration}; use time::{OffsetDateTime, PrimitiveDateTime}; #[cfg(test)] @@ -27,27 +27,34 @@ pub struct Db { runtime: Option, } -macro_rules! test_support { - ($self:ident, { $($token:tt)* }) => {{ - let body = async { - $($token)* - }; +pub trait BeginTransaction: Send + Sync { + type Database: sqlx::Database; - if cfg!(test) { - #[cfg(not(test))] - unreachable!(); + fn begin_transaction(&self) -> BoxFuture>>; +} - #[cfg(test)] - if let Some(background) = $self.background.as_ref() { - background.simulate_random_delay().await; - } +// In Postgres, serializable transactions are opt-in +impl BeginTransaction for Db { + type Database = sqlx::Postgres; - #[cfg(test)] - $self.runtime.as_ref().unwrap().block_on(body) - } else { - body.await + fn begin_transaction(&self) -> BoxFuture>> { + async move { + let mut tx = self.pool.begin().await?; + sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;") + .await?; + Ok(tx) } - }}; + .boxed() + } +} + +// In Sqlite, transactions are inherently serializable. +impl BeginTransaction for Db { + type Database = sqlx::Sqlite; + + fn begin_transaction(&self) -> BoxFuture>> { + async move { Ok(self.pool.begin().await?) }.boxed() + } } pub trait RowsAffected { @@ -88,7 +95,8 @@ impl Db { } pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; let query = " SELECT users.* FROM users @@ -96,13 +104,14 @@ impl Db { "; Ok(sqlx::query_as(query) .bind(&serde_json::json!(ids)) - .fetch_all(&self.pool) + .fetch_all(&mut tx) .await?) 
}) + .await } pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - test_support!(self, { + self.transact(|mut tx| async move { let query = " SELECT metrics_id FROM users @@ -110,9 +119,10 @@ impl Db { "; Ok(sqlx::query_scalar(query) .bind(id) - .fetch_one(&self.pool) + .fetch_one(&mut tx) .await?) }) + .await } pub async fn create_user( @@ -121,7 +131,7 @@ impl Db { admin: bool, params: NewUserParams, ) -> Result { - test_support!(self, { + self.transact(|mut tx| async { let query = " INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id) VALUES ($1, $2, $3, $4, $5) @@ -131,12 +141,13 @@ impl Db { let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) .bind(email_address) - .bind(params.github_login) - .bind(params.github_user_id) + .bind(¶ms.github_login) + .bind(¶ms.github_user_id) .bind(admin) .bind(Uuid::new_v4().to_string()) - .fetch_one(&self.pool) + .fetch_one(&mut tx) .await?; + tx.commit().await?; Ok(NewUserResult { user_id, metrics_id, @@ -144,6 +155,7 @@ impl Db { inviting_user_id: None, }) }) + .await } pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result> { @@ -209,7 +221,8 @@ impl Db { } pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; let like_string = Self::fuzzy_like_string(name_query); let query = " SELECT users.* @@ -222,27 +235,28 @@ impl Db { .bind(like_string) .bind(name_query) .bind(limit as i32) - .fetch_all(&self.pool) + .fetch_all(&mut tx) .await?) }) + .await } pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - test_support!(self, { + let ids = ids.iter().map(|id| id.0).collect::>(); + self.transact(|tx| async { + let mut tx = tx; let query = " SELECT users.* FROM users WHERE users.id = ANY ($1) "; - Ok(sqlx::query_as(query) - .bind(&ids.into_iter().map(|id| id.0).collect::>()) - .fetch_all(&self.pool) - .await?) 
+ Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?) }) + .await } pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - test_support!(self, { + self.transact(|mut tx| async move { let query = " SELECT metrics_id::text FROM users @@ -250,9 +264,10 @@ impl Db { "; Ok(sqlx::query_scalar(query) .bind(id) - .fetch_one(&self.pool) + .fetch_one(&mut tx) .await?) }) + .await } pub async fn create_user( @@ -261,7 +276,7 @@ impl Db { admin: bool, params: NewUserParams, ) -> Result { - test_support!(self, { + self.transact(|mut tx| async { let query = " INSERT INTO users (email_address, github_login, github_user_id, admin) VALUES ($1, $2, $3, $4) @@ -271,11 +286,13 @@ impl Db { let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) .bind(email_address) - .bind(params.github_login) + .bind(¶ms.github_login) .bind(params.github_user_id) .bind(admin) - .fetch_one(&self.pool) + .fetch_one(&mut tx) .await?; + tx.commit().await?; + Ok(NewUserResult { user_id, metrics_id, @@ -283,6 +300,7 @@ impl Db { inviting_user_id: None, }) }) + .await } pub async fn create_user_from_invite( @@ -290,9 +308,7 @@ impl Db { invite: &Invite, user: NewUserParams, ) -> Result> { - test_support!(self, { - let mut tx = self.pool.begin().await?; - + self.transact(|mut tx| async { let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( i32, Option, @@ -393,10 +409,11 @@ impl Db { signup_device_id, })) }) + .await } pub async fn create_signup(&self, signup: Signup) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async { sqlx::query( " INSERT INTO signups @@ -425,10 +442,12 @@ impl Db { .bind(&signup.editor_features) .bind(&signup.programming_languages) .bind(&signup.device_id) - .execute(&self.pool) + .execute(&mut tx) .await?; + tx.commit().await?; Ok(()) }) + .await } pub async fn create_invite_from_code( @@ -437,9 +456,7 @@ impl Db { email_address: &str, device_id: Option<&str>, ) -> Result { - test_support!(self, { - let mut tx = 
self.pool.begin().await?; - + self.transact(|mut tx| async { let existing_user: Option = sqlx::query_scalar( " SELECT id @@ -516,10 +533,11 @@ impl Db { email_confirmation_code, }) }) + .await } pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async { let emails = invites .iter() .map(|s| s.email_address.as_str()) @@ -532,15 +550,18 @@ impl Db { ", ) .bind(&emails) - .execute(&self.pool) + .execute(&mut tx) .await?; + tx.commit().await?; Ok(()) }) + .await } } impl Db where + Self: BeginTransaction, D: sqlx::Database + sqlx::migrate::MigrateDatabase, D::Connection: sqlx::migrate::Migrate, for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, @@ -627,18 +648,21 @@ where // users pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; Ok(sqlx::query_as(query) .bind(limit as i32) .bind((page * limit) as i32) - .fetch_all(&self.pool) + .fetch_all(&mut tx) .await?) }) + .await } pub async fn get_user_by_id(&self, id: UserId) -> Result> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; let query = " SELECT users.* FROM users @@ -647,16 +671,18 @@ where "; Ok(sqlx::query_as(query) .bind(&id) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await?) }) + .await } pub async fn get_users_with_no_invites( &self, invited_by_another_user: bool, ) -> Result> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; let query = format!( " SELECT users.* @@ -667,8 +693,9 @@ where if invited_by_another_user { " NOT" } else { "" } ); - Ok(sqlx::query_as(&query).fetch_all(&self.pool).await?) + Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?) 
}) + .await } pub async fn get_user_by_github_account( @@ -676,7 +703,8 @@ where github_login: &str, github_user_id: Option, ) -> Result> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; if let Some(github_user_id) = github_user_id { let mut user = sqlx::query_as::<_, User>( " @@ -688,7 +716,7 @@ where ) .bind(github_login) .bind(github_user_id) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await?; if user.is_none() { @@ -702,7 +730,7 @@ where ) .bind(github_user_id) .bind(github_login) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await?; } @@ -716,58 +744,62 @@ where ", ) .bind(github_login) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await?; Ok(user) } }) + .await } pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async { let query = "UPDATE users SET admin = $1 WHERE id = $2"; - Ok(sqlx::query(query) + sqlx::query(query) .bind(is_admin) .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + .execute(&mut tx) + .await?; + tx.commit().await?; + Ok(()) }) + .await } pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async move { let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; - Ok(sqlx::query(query) + sqlx::query(query) .bind(connected_once) .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) + .execute(&mut tx) + .await?; + tx.commit().await?; + Ok(()) }) + .await } pub async fn destroy_user(&self, id: UserId) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async move { let query = "DELETE FROM access_tokens WHERE user_id = $1;"; sqlx::query(query) .bind(id.0) - .execute(&self.pool) + .execute(&mut tx) .await .map(drop)?; let query = "DELETE FROM users WHERE id = $1;"; - Ok(sqlx::query(query) - .bind(id.0) - .execute(&self.pool) - .await - .map(drop)?) 
+ sqlx::query(query).bind(id.0).execute(&mut tx).await?; + tx.commit().await?; + Ok(()) }) + .await } // signups pub async fn get_waitlist_summary(&self) -> Result { - test_support!(self, { + self.transact(|mut tx| async move { Ok(sqlx::query_as( " SELECT @@ -784,13 +816,14 @@ where ) AS unsent ", ) - .fetch_one(&self.pool) + .fetch_one(&mut tx) .await?) }) + .await } pub async fn get_unsent_invites(&self, count: usize) -> Result> { - test_support!(self, { + self.transact(|mut tx| async move { Ok(sqlx::query_as( " SELECT @@ -803,16 +836,16 @@ where ", ) .bind(count as i32) - .fetch_all(&self.pool) + .fetch_all(&mut tx) .await?) }) + .await } // invite codes pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { if count > 0 { sqlx::query( " @@ -841,10 +874,11 @@ where tx.commit().await?; Ok(()) }) + .await } pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - test_support!(self, { + self.transact(|mut tx| async move { let result: Option<(String, i32)> = sqlx::query_as( " SELECT invite_code, invite_count @@ -853,7 +887,7 @@ where ", ) .bind(id) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await?; if let Some((code, count)) = result { Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) @@ -861,10 +895,12 @@ where Ok(None) } }) + .await } pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; sqlx::query_as( " SELECT * @@ -873,7 +909,7 @@ where ", ) .bind(code) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await? 
.ok_or_else(|| { Error::Http( @@ -882,6 +918,7 @@ where ) }) }) + .await } pub async fn create_room( @@ -889,8 +926,7 @@ where user_id: UserId, connection_id: ConnectionId, ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { let live_kit_room = nanoid::nanoid!(30); let room_id = sqlx::query_scalar( " @@ -920,7 +956,7 @@ where .await?; self.commit_room_transaction(room_id, tx).await - }) + }).await } pub async fn call( @@ -931,8 +967,7 @@ where called_user_id: UserId, initial_project_id: Option, ) -> Result<(proto::Room, proto::IncomingCall)> { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { sqlx::query( " INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id) @@ -951,15 +986,14 @@ where let incoming_call = Self::build_incoming_call(&room, called_user_id) .ok_or_else(|| anyhow!("failed to build incoming call"))?; Ok((room, incoming_call)) - }) + }).await } pub async fn incoming_call_for_user( &self, user_id: UserId, ) -> Result> { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { let room_id = sqlx::query_scalar::<_, RoomId>( " SELECT room_id @@ -978,6 +1012,7 @@ where Ok(None) } }) + .await } fn build_incoming_call( @@ -1013,8 +1048,7 @@ where room_id: RoomId, called_user_id: UserId, ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { sqlx::query( " DELETE FROM room_participants @@ -1028,6 +1062,7 @@ where self.commit_room_transaction(room_id, tx).await }) + .await } pub async fn decline_call( @@ -1035,8 +1070,7 @@ where expected_room_id: Option, user_id: UserId, ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { let room_id = sqlx::query_scalar( " DELETE FROM room_participants @@ -1053,6 +1087,7 @@ where 
self.commit_room_transaction(room_id, tx).await }) + .await } pub async fn cancel_call( @@ -1061,8 +1096,7 @@ where calling_connection_id: ConnectionId, called_user_id: UserId, ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { let room_id = sqlx::query_scalar( " DELETE FROM room_participants @@ -1079,7 +1113,7 @@ where } self.commit_room_transaction(room_id, tx).await - }) + }).await } pub async fn join_room( @@ -1088,8 +1122,7 @@ where user_id: UserId, connection_id: ConnectionId, ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { sqlx::query( " UPDATE room_participants @@ -1105,15 +1138,14 @@ where .await?; self.commit_room_transaction(room_id, tx).await }) + .await } pub async fn leave_room_for_connection( &self, connection_id: ConnectionId, ) -> Result> { - test_support!(self, { - let mut tx = self.pool.begin().await?; - + self.transact(|mut tx| async move { // Leave room. let room_id = sqlx::query_scalar::<_, RoomId>( " @@ -1198,6 +1230,7 @@ where Ok(None) } }) + .await } pub async fn update_room_participant_location( @@ -1206,13 +1239,13 @@ where connection_id: ConnectionId, location: proto::ParticipantLocation, ) -> Result { - test_support!(self, { - let mut tx = self.pool.begin().await?; - + self.transact(|tx| async { + let mut tx = tx; let location_kind; let location_project_id; match location .variant + .as_ref() .ok_or_else(|| anyhow!("invalid location"))? 
{ proto::participant_location::Variant::SharedProject(project) => { @@ -1245,6 +1278,7 @@ where self.commit_room_transaction(room_id, tx).await }) + .await } async fn commit_room_transaction( @@ -1375,8 +1409,7 @@ where connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result<(ProjectId, proto::Room)> { - test_support!(self, { - let mut tx = self.pool.begin().await?; + self.transact(|mut tx| async move { let project_id = sqlx::query_scalar( " INSERT INTO projects (room_id, host_user_id, host_connection_id) @@ -1428,16 +1461,65 @@ where let room = self.commit_room_transaction(room_id, tx).await?; Ok((project_id, room)) }) + .await } - // pub async fn join_project( - // &self, - // user_id: UserId, - // connection_id: ConnectionId, - // project_id: ProjectId, - // ) -> Result<(Project, ReplicaId)> { - // todo!() - // } + pub async fn update_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result<(proto::Room, Vec)> { + self.transact(|mut tx| async move { + let room_id: RoomId = sqlx::query_scalar( + " + SELECT room_id + FROM projects + WHERE id = $1 AND host_connection_id = $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + + for worktree in worktrees { + sqlx::query( + " + INSERT INTO worktrees (project_id, id, root_name) + VALUES ($1, $2, $3) + ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name + ", + ) + .bind(project_id) + .bind(worktree.id as i32) + .bind(&worktree.root_name) + .execute(&mut tx) + .await?; + } + + let mut params = "?,".repeat(worktrees.len()); + if !worktrees.is_empty() { + params.pop(); + } + let query = format!( + " + DELETE FROM worktrees + WHERE id NOT IN ({params}) + ", + ); + + let mut query = sqlx::query(&query); + for worktree in worktrees { + query = query.bind(worktree.id as i32); + } + query.execute(&mut tx).await?; + + let room = self.commit_room_transaction(room_id, 
tx).await?; + todo!() + }) + .await + } pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { todo!() @@ -1459,7 +1541,7 @@ where // contacts pub async fn get_contacts(&self, user_id: UserId) -> Result> { - test_support!(self, { + self.transact(|mut tx| async move { let query = " SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify FROM contacts @@ -1468,7 +1550,7 @@ where let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) .bind(user_id) - .fetch(&self.pool); + .fetch(&mut tx); let mut contacts = Vec::new(); while let Some(row) = rows.next().await { @@ -1507,10 +1589,11 @@ where Ok(contacts) }) + .await } pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - test_support!(self, { + self.transact(|mut tx| async move { let (id_a, id_b) = if user_id_1 < user_id_2 { (user_id_1, user_id_2) } else { @@ -1525,14 +1608,15 @@ where Ok(sqlx::query_scalar::<_, i32>(query) .bind(id_a.0) .bind(id_b.0) - .fetch_optional(&self.pool) + .fetch_optional(&mut tx) .await? .is_some()) }) + .await } pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) } else { @@ -1554,7 +1638,7 @@ where .bind(id_a.0) .bind(id_b.0) .bind(a_to_b) - .execute(&self.pool) + .execute(&mut tx) .await?; if result.rows_affected() == 1 { @@ -1562,11 +1646,11 @@ where } else { Err(anyhow!("contact already requested"))? 
} - }) + }).await } pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) } else { @@ -1579,7 +1663,7 @@ where let result = sqlx::query(query) .bind(id_a.0) .bind(id_b.0) - .execute(&self.pool) + .execute(&mut tx) .await?; if result.rows_affected() == 1 { @@ -1588,6 +1672,7 @@ where Err(anyhow!("no such contact"))? } }) + .await } pub async fn dismiss_contact_notification( @@ -1595,7 +1680,7 @@ where user_id: UserId, contact_user_id: UserId, ) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async move { let (id_a, id_b, a_to_b) = if user_id < contact_user_id { (user_id, contact_user_id, true) } else { @@ -1617,7 +1702,7 @@ where .bind(id_a.0) .bind(id_b.0) .bind(a_to_b) - .execute(&self.pool) + .execute(&mut tx) .await?; if result.rows_affected() == 0 { @@ -1626,6 +1711,7 @@ where Ok(()) }) + .await } pub async fn respond_to_contact_request( @@ -1634,7 +1720,7 @@ where requester_id: UserId, accept: bool, ) -> Result<()> { - test_support!(self, { + self.transact(|mut tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) } else { @@ -1650,7 +1736,7 @@ where .bind(id_a.0) .bind(id_b.0) .bind(a_to_b) - .execute(&self.pool) + .execute(&mut tx) .await? } else { let query = " @@ -1661,7 +1747,7 @@ where .bind(id_a.0) .bind(id_b.0) .bind(a_to_b) - .execute(&self.pool) + .execute(&mut tx) .await? }; if result.rows_affected() == 1 { @@ -1670,6 +1756,7 @@ where Err(anyhow!("no such contact request"))? 
} }) + .await } // access tokens @@ -1680,7 +1767,8 @@ where access_token_hash: &str, max_access_token_count: usize, ) -> Result<()> { - test_support!(self, { + self.transact(|tx| async { + let mut tx = tx; let insert_query = " INSERT INTO access_tokens (user_id, hash) VALUES ($1, $2); @@ -1696,7 +1784,6 @@ where ) "; - let mut tx = self.pool.begin().await?; sqlx::query(insert_query) .bind(user_id.0) .bind(access_token_hash) @@ -1710,10 +1797,11 @@ where .await?; Ok(tx.commit().await?) }) + .await } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - test_support!(self, { + self.transact(|mut tx| async move { let query = " SELECT hash FROM access_tokens @@ -1722,9 +1810,51 @@ where "; Ok(sqlx::query_scalar(query) .bind(user_id.0) - .fetch_all(&self.pool) + .fetch_all(&mut tx) .await?) }) + .await + } + + async fn transact(&self, f: F) -> Result + where + F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut, + Fut: Send + Future>, + { + let body = async { + loop { + let tx = self.begin_transaction().await?; + match f(tx).await { + Ok(result) => return Ok(result), + Err(error) => match error { + Error::Database(error) + if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("hey") => + { + // Retry (don't break the loop) + } + error @ _ => return Err(error), + }, + } + } + }; + + #[cfg(test)] + { + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + self.runtime.as_ref().unwrap().block_on(body) + } + + #[cfg(not(test))] + { + body.await + } } } diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 518530c539bdcffae03059732e5e9dba4401ac56..be21999a4567f385143bfeaba05101a7cd185ce5 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -4,6 +4,7 @@ pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String), + Database(sqlx::Error), Internal(anyhow::Error), } @@ -15,7 +16,7 @@ impl From for Error { impl From 
for Error { fn from(error: sqlx::Error) -> Self { - Self::Internal(error.into()) + Self::Database(error) } } @@ -41,6 +42,9 @@ impl IntoResponse for Error { fn into_response(self) -> axum::response::Response { match self { Error::Http(code, message) => (code, message).into_response(), + Error::Database(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } Error::Internal(error) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } @@ -52,6 +56,7 @@ impl std::fmt::Debug for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Http(code, message) => (code, message).fmt(f), + Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } @@ -61,6 +66,7 @@ impl std::fmt::Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Error::Http(code, message) => write!(f, "{code}: {message}"), + Error::Database(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index bed6ebf9cd649a10bb9bce2d931606e2a56ed281..d8ca51e6cd7b92513264f187ff7b99c16e31c340 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1087,30 +1087,31 @@ impl Server { response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - { - let mut state = self.store().await; - let guest_connection_ids = state - .read_project(project_id, request.sender_connection_id)? 
- .guest_connection_ids(); - let room = state.update_project( + let (room, guest_connection_ids) = self + .app_state + .db + .update_project( project_id, - &request.payload.worktrees, request.sender_connection_id, - )?; - broadcast( - request.sender_connection_id, - guest_connection_ids, - |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) - }, - ); - self.room_updated(room); - response.send(proto::Ack {})?; - }; + &request.payload.worktrees, + ) + .await?; + broadcast( + request.sender_connection_id, + guest_connection_ids, + |connection_id| { + self.peer.send( + connection_id, + proto::ProjectUpdated { + project_id: project_id.to_proto(), + worktrees: request.payload.worktrees.clone(), + room_version: room.version, + }, + ) + }, + ); + self.room_updated(&room); + response.send(proto::Ack {})?; Ok(()) } From 42bb5f0e9f7552a861a76a4cfda02462536aba89 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 08:48:16 +0100 Subject: [PATCH 019/109] Add random delay after returning results from the database --- crates/collab/src/db.rs | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index b561ba045d1344bdce46f888df0614639a82dd8d..fb91e92808e81d0ccc4af7151069ece8ee5cb31d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1848,7 +1848,13 @@ where background.simulate_random_delay().await; } - self.runtime.as_ref().unwrap().block_on(body) + let result = self.runtime.as_ref().unwrap().block_on(body); + + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + result } #[cfg(not(test))] From 3e8fcb04f71f877a641f19f884f1d4f8cc3da188 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 09:00:56 +0100 Subject: [PATCH 020/109] Finish implementing `Db::update_project` --- crates/collab/src/db.rs | 17 +++++++++++- 
crates/collab/src/rpc.rs | 9 +++---- crates/collab/src/rpc/store.rs | 49 ---------------------------------- crates/project/src/project.rs | 6 ++--- crates/rpc/proto/zed.proto | 7 ----- crates/rpc/src/proto.rs | 2 -- 6 files changed, 22 insertions(+), 68 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index fb91e92808e81d0ccc4af7151069ece8ee5cb31d..ba014624af4845a950cb5a94f14b579fe022ad87 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1515,8 +1515,23 @@ where } query.execute(&mut tx).await?; + let mut guest_connection_ids = Vec::new(); + { + let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 AND is_host = FALSE + ", + ) + .fetch(&mut tx); + while let Some(connection_id) = db_guest_connection_ids.next().await { + guest_connection_ids.push(ConnectionId(connection_id? as u32)); + } + } + let room = self.commit_room_transaction(room_id, tx).await?; - todo!() + Ok((room, guest_connection_ids)) }) .await } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index d8ca51e6cd7b92513264f187ff7b99c16e31c340..daf898ddf6263c51c15a2ad6345afa0f6fe4f96a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1100,13 +1100,10 @@ impl Server { request.sender_connection_id, guest_connection_ids, |connection_id| { - self.peer.send( + self.peer.forward_send( + request.sender_connection_id, connection_id, - proto::ProjectUpdated { - project_id: project_id.to_proto(), - worktrees: request.payload.worktrees.clone(), - room_version: room.version, - }, + request.payload.clone(), ) }, ); diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index a9793e9fb67af8e97c11d79d59a3b7927d24d3cd..a9a15e7b2aa775b4dba3de5a7c64d6623e4b9489 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -253,55 +253,6 @@ impl Store { } } - pub fn update_project( - &mut self, - project_id: 
ProjectId, - worktrees: &[proto::WorktreeMetadata], - connection_id: ConnectionId, - ) -> Result<&proto::Room> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id { - let mut old_worktrees = mem::take(&mut project.worktrees); - for worktree in worktrees { - if let Some(old_worktree) = old_worktrees.remove(&worktree.id) { - project.worktrees.insert(worktree.id, old_worktree); - } else { - project.worktrees.insert( - worktree.id, - Worktree { - root_name: worktree.root_name.clone(), - visible: worktree.visible, - ..Default::default() - }, - ); - } - } - - let room = self - .rooms - .get_mut(&project.room_id) - .ok_or_else(|| anyhow!("no such room"))?; - let participant_project = room - .participants - .iter_mut() - .flat_map(|participant| &mut participant.projects) - .find(|project| project.id == project_id.to_proto()) - .ok_or_else(|| anyhow!("no such project"))?; - participant_project.worktree_root_names = worktrees - .iter() - .filter(|worktree| worktree.visible) - .map(|worktree| worktree.root_name.clone()) - .collect(); - - Ok(room) - } else { - Err(anyhow!("no such project"))? 
- } - } - pub fn update_diagnostic_summary( &mut self, project_id: ProjectId, diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index d01571f44b1f7df78698420af0f19e282a4d8c55..c59b19de8fe2774a3d9b1c6b80a529e40d850c3b 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -373,7 +373,7 @@ impl Project { client.add_model_message_handler(Self::handle_start_language_server); client.add_model_message_handler(Self::handle_update_language_server); client.add_model_message_handler(Self::handle_remove_collaborator); - client.add_model_message_handler(Self::handle_project_updated); + client.add_model_message_handler(Self::handle_update_project); client.add_model_message_handler(Self::handle_unshare_project); client.add_model_message_handler(Self::handle_create_buffer_for_peer); client.add_model_message_handler(Self::handle_update_buffer_file); @@ -4533,9 +4533,9 @@ impl Project { }) } - async fn handle_project_updated( + async fn handle_update_project( this: ModelHandle, - envelope: TypedEnvelope, + envelope: TypedEnvelope, client: Arc, mut cx: AsyncAppContext, ) -> Result<()> { diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 94880ce9f56e80e5b677eafa65878295a5b424e7..e688cad1f8e01a6c1926712438a16c85927b5d60 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -48,7 +48,6 @@ message Envelope { OpenBufferForSymbolResponse open_buffer_for_symbol_response = 40; UpdateProject update_project = 41; - ProjectUpdated project_updated = 42; UpdateWorktree update_worktree = 43; CreateProjectEntry create_project_entry = 45; @@ -257,12 +256,6 @@ message UpdateProject { repeated WorktreeMetadata worktrees = 2; } -message ProjectUpdated { - uint64 project_id = 1; - repeated WorktreeMetadata worktrees = 2; - uint64 room_version = 3; -} - message JoinProject { uint64 project_id = 1; } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 
31f53564a8b9d99c7bba00de3de969f95cfc1498..6d9bc9a0aa348af8c1a14f442323fcf06064688e 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -144,7 +144,6 @@ messages!( (PrepareRename, Background), (PrepareRenameResponse, Background), (ProjectEntryResponse, Foreground), - (ProjectUpdated, Foreground), (RemoveContact, Foreground), (ReloadBuffers, Foreground), (ReloadBuffersResponse, Foreground), @@ -261,7 +260,6 @@ entity_messages!( OpenBufferByPath, OpenBufferForSymbol, PerformRename, - ProjectUpdated, PrepareRename, ReloadBuffers, RemoveProjectCollaborator, From 6cbf19722620c6836226a36ba5c6107d2f6d64d5 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 10:41:21 +0100 Subject: [PATCH 021/109] Determine whether a contact is busy via the database --- .../20221109000000_test_schema.sql | 2 +- .../20221111092550_reconnection_support.sql | 1 + crates/collab/src/db.rs | 38 ++++++++++++--- crates/collab/src/db_tests.rs | 46 +++++++++++++------ crates/collab/src/rpc.rs | 10 ++-- crates/collab/src/rpc/store.rs | 22 ++++----- 6 files changed, 81 insertions(+), 38 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index d262d6a8bd414a40cc71cc56690b3232e8eaaa81..d6759fb5246cfe6653db215cfc5ffe7f733f5d8b 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -56,7 +56,7 @@ CREATE TABLE "project_collaborators" ( "is_host" BOOLEAN NOT NULL ); CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); -CREATE UNIQUE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, diff --git 
a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 7b82ce9ce7f49ec953a2c8ef54e2cdbfe07d3274..617e282a0a479ecefc4d9a7339397c7a2b3c32d0 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -18,6 +18,7 @@ CREATE TABLE "project_collaborators" ( "is_host" BOOLEAN NOT NULL ); CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); CREATE TABLE IF NOT EXISTS "worktrees" ( "id" INTEGER NOT NULL, diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index ba014624af4845a950cb5a94f14b579fe022ad87..1df96870d6bc0b3fb1b69cc08fcde073fcf34e36 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1558,24 +1558,25 @@ where pub async fn get_contacts(&self, user_id: UserId) -> Result> { self.transact(|mut tx| async move { let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify + SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy FROM contacts + LEFT JOIN room_participants ON room_participants.user_id = $1 WHERE user_id_a = $1 OR user_id_b = $1; "; - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query) + let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query) .bind(user_id) .fetch(&mut tx); let mut contacts = Vec::new(); while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?; - + let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?; if user_id_a == user_id { if accepted { contacts.push(Contact::Accepted { user_id: user_id_b, should_notify: should_notify && a_to_b, + busy }); } else if a_to_b { 
contacts.push(Contact::Outgoing { user_id: user_id_b }) @@ -1589,6 +1590,7 @@ where contacts.push(Contact::Accepted { user_id: user_id_a, should_notify: should_notify && !a_to_b, + busy }); } else if a_to_b { contacts.push(Contact::Incoming { @@ -1607,6 +1609,23 @@ where .await } + pub async fn is_user_busy(&self, user_id: UserId) -> Result { + self.transact(|mut tx| async move { + Ok(sqlx::query_scalar::<_, i32>( + " + SELECT 1 + FROM room_participants + WHERE room_participants.user_id = $1 + ", + ) + .bind(user_id) + .fetch_optional(&mut tx) + .await? + .is_some()) + }) + .await + } + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { self.transact(|mut tx| async move { let (id_a, id_b) = if user_id_1 < user_id_2 { @@ -1657,6 +1676,7 @@ where .await?; if result.rows_affected() == 1 { + tx.commit().await?; Ok(()) } else { Err(anyhow!("contact already requested"))? @@ -1682,6 +1702,7 @@ where .await?; if result.rows_affected() == 1 { + tx.commit().await?; Ok(()) } else { Err(anyhow!("no such contact"))? @@ -1721,10 +1742,11 @@ where .await?; if result.rows_affected() == 0 { - Err(anyhow!("no such contact request"))?; + Err(anyhow!("no such contact request"))? + } else { + tx.commit().await?; + Ok(()) } - - Ok(()) }) .await } @@ -1766,6 +1788,7 @@ where .await? }; if result.rows_affected() == 1 { + tx.commit().await?; Ok(()) } else { Err(anyhow!("no such contact request"))? 
@@ -1977,6 +2000,7 @@ pub enum Contact { Accepted { user_id: UserId, should_notify: bool, + busy: bool, }, Outgoing { user_id: UserId, diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db_tests.rs index 8eda7d34e298c975e53140c9ce3a7aed1551b706..444e60ddeb0c5e03df39e132189eac9ecca46033 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db_tests.rs @@ -258,7 +258,8 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true + should_notify: true, + busy: false, }], ); assert!(db.has_contact(user_1, user_2).await.unwrap()); @@ -268,6 +269,7 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { &[Contact::Accepted { user_id: user_1, should_notify: false, + busy: false, }] ); @@ -284,6 +286,7 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { &[Contact::Accepted { user_id: user_2, should_notify: true, + busy: false, }] ); @@ -296,6 +299,7 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { &[Contact::Accepted { user_id: user_2, should_notify: false, + busy: false, }] ); @@ -309,10 +313,12 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { Contact::Accepted { user_id: user_2, should_notify: false, + busy: false, }, Contact::Accepted { user_id: user_3, - should_notify: false + should_notify: false, + busy: false, } ] ); @@ -320,7 +326,8 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false + should_notify: false, + busy: false, }], ); @@ -335,14 +342,16 @@ test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( 
db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false + should_notify: false, + busy: false, }], ); }); @@ -504,14 +513,16 @@ async fn test_invite_codes() { db.get_contacts(user1).await.unwrap(), [Contact::Accepted { user_id: user2, - should_notify: true + should_notify: true, + busy: false, }] ); assert_eq!( db.get_contacts(user2).await.unwrap(), [Contact::Accepted { user_id: user1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( @@ -550,11 +561,13 @@ async fn test_invite_codes() { [ Contact::Accepted { user_id: user2, - should_notify: true + should_notify: true, + busy: false, }, Contact::Accepted { user_id: user3, - should_notify: true + should_notify: true, + busy: false, } ] ); @@ -562,7 +575,8 @@ async fn test_invite_codes() { db.get_contacts(user3).await.unwrap(), [Contact::Accepted { user_id: user1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( @@ -607,15 +621,18 @@ async fn test_invite_codes() { [ Contact::Accepted { user_id: user2, - should_notify: true + should_notify: true, + busy: false, }, Contact::Accepted { user_id: user3, - should_notify: true + should_notify: true, + busy: false, }, Contact::Accepted { user_id: user4, - should_notify: true + should_notify: true, + busy: false, } ] ); @@ -623,7 +640,8 @@ async fn test_invite_codes() { db.get_contacts(user4).await.unwrap(), [Contact::Accepted { user_id: user1, - should_notify: false + should_notify: false, + busy: false, }] ); assert_eq!( diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index daf898ddf6263c51c15a2ad6345afa0f6fe4f96a..627a22426a76f30964d61d364a14529154498606 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -465,7 +465,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? 
{ if let Some(code) = &user.invite_code { let store = self.store().await; - let invitee_contact = store.contact_for_user(invitee_id, true); + let invitee_contact = store.contact_for_user(invitee_id, true, false); for connection_id in store.connection_ids_for_user(inviter_id) { self.peer.send( connection_id, @@ -895,8 +895,9 @@ impl Server { async fn update_user_contacts(self: &Arc, user_id: UserId) -> Result<()> { let contacts = self.app_state.db.get_contacts(user_id).await?; + let busy = self.app_state.db.is_user_busy(user_id).await?; let store = self.store().await; - let updated_contact = store.contact_for_user(user_id, false); + let updated_contact = store.contact_for_user(user_id, false, busy); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, @@ -1575,6 +1576,7 @@ impl Server { .db .respond_to_contact_request(responder_id, requester_id, accept) .await?; + let busy = self.app_state.db.is_user_busy(requester_id).await?; let store = self.store().await; // Update responder with new contact @@ -1582,7 +1584,7 @@ impl Server { if accept { update .contacts - .push(store.contact_for_user(requester_id, false)); + .push(store.contact_for_user(requester_id, false, busy)); } update .remove_incoming_requests @@ -1596,7 +1598,7 @@ impl Server { if accept { update .contacts - .push(store.contact_for_user(responder_id, true)); + .push(store.contact_for_user(responder_id, true, busy)); } update .remove_outgoing_requests diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index a9a15e7b2aa775b4dba3de5a7c64d6623e4b9489..4be93547889683d75a7439fb98673ad4532e308a 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::{mem, path::PathBuf, str}; +use std::{path::PathBuf, str}; use tracing::instrument; pub type RoomId = u64; @@ 
-156,14 +156,6 @@ impl Store { .is_empty() } - fn is_user_busy(&self, user_id: UserId) -> bool { - self.connected_users - .get(&user_id) - .unwrap_or(&Default::default()) - .active_call - .is_some() - } - pub fn build_initial_contacts_update( &self, contacts: Vec, @@ -175,10 +167,11 @@ impl Store { db::Contact::Accepted { user_id, should_notify, + busy, } => { update .contacts - .push(self.contact_for_user(user_id, should_notify)); + .push(self.contact_for_user(user_id, should_notify, busy)); } db::Contact::Outgoing { user_id } => { update.outgoing_requests.push(user_id.to_proto()) @@ -198,11 +191,16 @@ impl Store { update } - pub fn contact_for_user(&self, user_id: UserId, should_notify: bool) -> proto::Contact { + pub fn contact_for_user( + &self, + user_id: UserId, + should_notify: bool, + busy: bool, + ) -> proto::Contact { proto::Contact { user_id: user_id.to_proto(), online: self.is_user_online(user_id), - busy: self.is_user_busy(user_id), + busy, should_notify, } } From be523617c98bbe63d5bf002fa4dd7e12872afdf7 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 11:44:26 +0100 Subject: [PATCH 022/109] Start reworking `join_project` to use the database --- .../20221109000000_test_schema.sql | 3 + crates/collab/src/db.rs | 152 ++++++++++++++++-- crates/collab/src/rpc.rs | 43 +++-- 3 files changed, 164 insertions(+), 34 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index d6759fb5246cfe6653db215cfc5ffe7f733f5d8b..1a09dff7807e02fc1ea98548f4c1095316620873 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -62,6 +62,9 @@ CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "root_name" VARCHAR NOT NULL, + "visible" BOOL NOT NULL, + "scan_id" INTEGER NOT NULL, + "is_complete" 
BOOL NOT NULL, PRIMARY KEY(project_id, id) ); CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 1df96870d6bc0b3fb1b69cc08fcde073fcf34e36..88b6f20953a9c3019bb2771127832de3f1d85eb9 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,7 +1,7 @@ use crate::{Error, Result}; use anyhow::anyhow; use axum::http::StatusCode; -use collections::HashMap; +use collections::{BTreeMap, HashMap, HashSet}; use futures::{future::BoxFuture, FutureExt, StreamExt}; use rpc::{proto, ConnectionId}; use serde::{Deserialize, Serialize}; @@ -10,7 +10,11 @@ use sqlx::{ types::Uuid, FromRow, }; -use std::{future::Future, path::Path, time::Duration}; +use std::{ + future::Future, + path::{Path, PathBuf}, + time::Duration, +}; use time::{OffsetDateTime, PrimitiveDateTime}; #[cfg(test)] @@ -1404,13 +1408,26 @@ where pub async fn share_project( &self, - room_id: RoomId, - user_id: UserId, + expected_room_id: RoomId, connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result<(ProjectId, proto::Room)> { self.transact(|mut tx| async move { - let project_id = sqlx::query_scalar( + let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( + " + SELECT room_id, user_id + FROM room_participants + WHERE answering_connection_id = $1 + ", + ) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + if room_id != expected_room_id { + return Err(anyhow!("shared project on unexpected room"))?; + } + + let project_id: ProjectId = sqlx::query_scalar( " INSERT INTO projects (room_id, host_user_id, host_connection_id) VALUES ($1, $2, $3) @@ -1421,8 +1438,7 @@ where .bind(user_id) .bind(connection_id.0 as i32) .fetch_one(&mut tx) - .await - .map(ProjectId)?; + .await?; for worktree in worktrees { sqlx::query( @@ -1536,6 +1552,111 @@ where .await } + pub async fn join_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result<(Project, 
i32)> { + self.transact(|mut tx| async move { + let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( + " + SELECT room_id, user_id + FROM room_participants + WHERE answering_connection_id = $1 + ", + ) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + + // Ensure project id was shared on this room. + sqlx::query( + " + SELECT 1 + FROM projects + WHERE project_id = $1 AND room_id = $2 + ", + ) + .bind(project_id) + .bind(room_id) + .fetch_one(&mut tx) + .await?; + + let replica_ids = sqlx::query_scalar::<_, i32>( + " + SELECT replica_id + FROM project_collaborators + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await?; + let replica_ids = HashSet::from_iter(replica_ids); + let mut replica_id = 1; + while replica_ids.contains(&replica_id) { + replica_id += 1; + } + + sqlx::query( + " + INSERT INTO project_collaborators ( + project_id, + connection_id, + user_id, + replica_id, + is_host + ) + VALUES ($1, $2, $3, $4, $5) + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .bind(user_id) + .bind(replica_id) + .bind(false) + .execute(&mut tx) + .await?; + + tx.commit().await?; + todo!() + }) + .await + // sqlx::query( + // " + // SELECT replica_id + // FROM project_collaborators + // WHERE project_id = $ + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .bind(user_id) + // .bind(0) + // .bind(true) + // .execute(&mut tx) + // .await?; + // sqlx::query( + // " + // INSERT INTO project_collaborators ( + // project_id, + // connection_id, + // user_id, + // replica_id, + // is_host + // ) + // VALUES ($1, $2, $3, $4, $5) + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .bind(user_id) + // .bind(0) + // .bind(true) + // .execute(&mut tx) + // .await?; + } + pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { todo!() // test_support!(self, { @@ -1967,11 +2088,11 @@ pub struct Room { } id_type!(ProjectId); -#[derive(Clone, Debug, 
Default, FromRow, Serialize, PartialEq)] pub struct Project { pub id: ProjectId, - pub host_user_id: UserId, - pub unregistered: bool, + pub collaborators: Vec, + pub worktrees: BTreeMap, + pub language_servers: Vec, } #[derive(Clone, Debug, Default, FromRow, PartialEq)] @@ -1983,6 +2104,17 @@ pub struct ProjectCollaborator { pub is_host: bool, } +#[derive(Default)] +pub struct Worktree { + pub abs_path: PathBuf, + pub root_name: String, + pub visible: bool, + pub entries: BTreeMap, + pub diagnostic_summaries: BTreeMap, + pub scan_id: u64, + pub is_complete: bool, +} + pub struct LeftProject { pub id: ProjectId, pub host_user_id: UserId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 627a22426a76f30964d61d364a14529154498606..02d8f25f38af2464ba63076da6cb11ed6ee28225 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -862,7 +862,6 @@ impl Server { .db .share_project( RoomId::from_proto(request.payload.room_id), - request.sender_user_id, request.sender_connection_id, &request.payload.worktrees, ) @@ -942,15 +941,21 @@ impl Server { tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project"); - let mut store = self.store().await; - let (project, replica_id) = store.join_project(request.sender_connection_id, project_id)?; - let peer_count = project.guests.len(); - let mut collaborators = Vec::with_capacity(peer_count); - collaborators.push(proto::Collaborator { - peer_id: project.host_connection_id.0, - replica_id: 0, - user_id: project.host.user_id.to_proto(), - }); + let (project, replica_id) = self + .app_state + .db + .join_project(project_id, request.sender_connection_id) + .await?; + + let collaborators = project + .collaborators + .iter() + .map(|collaborator| proto::Collaborator { + peer_id: collaborator.connection_id as u32, + replica_id: collaborator.replica_id as u32, + user_id: collaborator.user_id.to_proto(), + }) + .collect::>(); let worktrees = project .worktrees .iter() @@ -962,22 +967,12 
@@ impl Server { }) .collect::>(); - // Add all guests other than the requesting user's own connections as collaborators - for (guest_conn_id, guest) in &project.guests { - if request.sender_connection_id != *guest_conn_id { - collaborators.push(proto::Collaborator { - peer_id: guest_conn_id.0, - replica_id: guest.replica_id as u32, - user_id: guest.user_id.to_proto(), - }); - } - } - - for conn_id in project.connection_ids() { - if conn_id != request.sender_connection_id { + for collaborator in &project.collaborators { + let connection_id = ConnectionId(collaborator.connection_id as u32); + if connection_id != request.sender_connection_id { self.peer .send( - conn_id, + connection_id, proto::AddProjectCollaborator { project_id: project_id.to_proto(), collaborator: Some(proto::Collaborator { From 974ef967a313868b49f70fe1ea5491adcd9b276d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 16:37:32 +0100 Subject: [PATCH 023/109] Move `Store::join_project` to `Db::join_project` Co-Authored-By: Nathan Sobo --- .../20221109000000_test_schema.sql | 59 +++- .../20221111092550_reconnection_support.sql | 42 ++- crates/collab/src/db.rs | 275 +++++++++++++----- crates/collab/src/integration_tests.rs | 8 +- crates/collab/src/rpc.rs | 37 +-- crates/collab/src/rpc/store.rs | 49 ---- crates/rpc/proto/zed.proto | 7 - 7 files changed, 311 insertions(+), 166 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 1a09dff7807e02fc1ea98548f4c1095316620873..cffb549a891cb97e83bf16a428d4b8a9a57669d1 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -47,21 +47,11 @@ CREATE TABLE "projects" ( "host_connection_id" INTEGER NOT NULL ); -CREATE TABLE "project_collaborators" ( - "id" INTEGER PRIMARY KEY, - "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE 
CASCADE, - "connection_id" INTEGER NOT NULL, - "user_id" INTEGER NOT NULL, - "replica_id" INTEGER NOT NULL, - "is_host" BOOLEAN NOT NULL -); -CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); -CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); - CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "project_id" INTEGER NOT NULL REFERENCES projects (id), "root_name" VARCHAR NOT NULL, + "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, "scan_id" INTEGER NOT NULL, "is_complete" BOOL NOT NULL, @@ -69,6 +59,51 @@ CREATE TABLE "worktrees" ( ); CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); +CREATE TABLE "worktree_entries" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "is_dir" BOOL NOT NULL, + "path" VARCHAR NOT NULL, + "inode" INTEGER NOT NULL, + "mtime_seconds" INTEGER NOT NULL, + "mtime_nanos" INTEGER NOT NULL, + "is_symlink" BOOL NOT NULL, + "is_ignored" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, id) +); +CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_diagnostic_summaries" ( + "path" VARCHAR NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "language_server_id" INTEGER NOT NULL, + "error_count" INTEGER NOT NULL, + "warning_count" INTEGER NOT NULL, + PRIMARY KEY(project_id, worktree_id, path) +); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); + +CREATE TABLE "language_servers" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES 
projects (id), + "name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); + +CREATE TABLE "project_collaborators" ( + "id" INTEGER PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "connection_id" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); + CREATE TABLE "room_participants" ( "id" INTEGER PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 617e282a0a479ecefc4d9a7339397c7a2b3c32d0..a5b49ad7636ef5e4aa398a31d199bf7e49bc5dd4 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -20,14 +20,52 @@ CREATE TABLE "project_collaborators" ( CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); -CREATE TABLE IF NOT EXISTS "worktrees" ( +CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "project_id" INTEGER NOT NULL REFERENCES projects (id), "root_name" VARCHAR NOT NULL, + "abs_path" VARCHAR NOT NULL, + "visible" BOOL NOT NULL, + "scan_id" INTEGER NOT NULL, + "is_complete" BOOL NOT NULL, PRIMARY KEY(project_id, id) ); CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); +CREATE TABLE "worktree_entries" ( + "id" INTEGER NOT NULL, 
+ "project_id" INTEGER NOT NULL REFERENCES projects (id), + "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "is_dir" BOOL NOT NULL, + "path" VARCHAR NOT NULL, + "inode" INTEGER NOT NULL, + "mtime_seconds" INTEGER NOT NULL, + "mtime_nanos" INTEGER NOT NULL, + "is_symlink" BOOL NOT NULL, + "is_ignored" BOOL NOT NULL, + PRIMARY KEY(project_id, worktree_id, id) +); +CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); + +CREATE TABLE "worktree_diagnostic_summaries" ( + "path" VARCHAR NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "language_server_id" INTEGER NOT NULL, + "error_count" INTEGER NOT NULL, + "warning_count" INTEGER NOT NULL, + PRIMARY KEY(project_id, worktree_id, path) +); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); + +CREATE TABLE "language_servers" ( + "id" INTEGER NOT NULL, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "name" VARCHAR NOT NULL, + PRIMARY KEY(project_id, id) +); +CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); + CREATE TABLE IF NOT EXISTS "room_participants" ( "id" SERIAL PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 88b6f20953a9c3019bb2771127832de3f1d85eb9..6db4ad101b35170554433f2e71f52021fddbf60f 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -10,11 +10,7 @@ use sqlx::{ types::Uuid, FromRow, }; -use std::{ - future::Future, - path::{Path, PathBuf}, - time::Duration, -}; +use std::{future::Future, path::Path, time::Duration}; use time::{OffsetDateTime, PrimitiveDateTime}; #[cfg(test)] @@ -1443,13 +1439,17 @@ where for worktree in worktrees { sqlx::query( " - INSERT INTO worktrees (id, project_id, root_name) - 
VALUES ($1, $2, $3) + INSERT INTO worktrees (project_id, id, root_name, abs_path, visible, scan_id, is_complete) + VALUES ($1, $2, $3, $4, $5, $6, $7) ", ) - .bind(worktree.id as i32) .bind(project_id) + .bind(worktree.id as i32) .bind(&worktree.root_name) + .bind(&*String::from_utf8_lossy(&worktree.abs_path)) + .bind(worktree.visible) + .bind(0) + .bind(false) .execute(&mut tx) .await?; } @@ -1502,32 +1502,36 @@ where for worktree in worktrees { sqlx::query( " - INSERT INTO worktrees (project_id, id, root_name) - VALUES ($1, $2, $3) + INSERT INTO worktrees (project_id, id, root_name, abs_path, visible, scan_id, is_complete) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name ", ) .bind(project_id) .bind(worktree.id as i32) .bind(&worktree.root_name) + .bind(String::from_utf8_lossy(&worktree.abs_path).as_ref()) + .bind(worktree.visible) + .bind(0) + .bind(false) .execute(&mut tx) .await?; } - let mut params = "?,".repeat(worktrees.len()); + let mut params = "(?, ?),".repeat(worktrees.len()); if !worktrees.is_empty() { params.pop(); } let query = format!( " DELETE FROM worktrees - WHERE id NOT IN ({params}) + WHERE (project_id, id) NOT IN ({params}) ", ); let mut query = sqlx::query(&query); for worktree in worktrees { - query = query.bind(worktree.id as i32); + query = query.bind(project_id).bind(WorktreeId(worktree.id as i32)); } query.execute(&mut tx).await?; @@ -1556,7 +1560,7 @@ where &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result<(Project, i32)> { + ) -> Result<(Project, ReplicaId)> { self.transact(|mut tx| async move { let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( " @@ -1574,7 +1578,7 @@ where " SELECT 1 FROM projects - WHERE project_id = $1 AND room_id = $2 + WHERE id = $1 AND room_id = $2 ", ) .bind(project_id) @@ -1582,9 +1586,9 @@ where .fetch_one(&mut tx) .await?; - let replica_ids = sqlx::query_scalar::<_, i32>( + let mut collaborators = 
sqlx::query_as::<_, ProjectCollaborator>( " - SELECT replica_id + SELECT * FROM project_collaborators WHERE project_id = $1 ", @@ -1592,11 +1596,21 @@ where .bind(project_id) .fetch_all(&mut tx) .await?; - let replica_ids = HashSet::from_iter(replica_ids); - let mut replica_id = 1; + let replica_ids = collaborators + .iter() + .map(|c| c.replica_id) + .collect::>(); + let mut replica_id = ReplicaId(1); while replica_ids.contains(&replica_id) { - replica_id += 1; + replica_id.0 += 1; } + let new_collaborator = ProjectCollaborator { + project_id, + connection_id: connection_id.0 as i32, + user_id, + replica_id, + is_host: false, + }; sqlx::query( " @@ -1610,51 +1624,140 @@ where VALUES ($1, $2, $3, $4, $5) ", ) - .bind(project_id) - .bind(connection_id.0 as i32) - .bind(user_id) - .bind(replica_id) - .bind(false) + .bind(new_collaborator.project_id) + .bind(new_collaborator.connection_id) + .bind(new_collaborator.user_id) + .bind(new_collaborator.replica_id) + .bind(new_collaborator.is_host) .execute(&mut tx) .await?; + collaborators.push(new_collaborator); + + let worktree_rows = sqlx::query_as::<_, WorktreeRow>( + " + SELECT * + FROM worktrees + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await?; + let mut worktrees = worktree_rows + .into_iter() + .map(|worktree_row| { + ( + worktree_row.id, + Worktree { + id: worktree_row.id, + abs_path: worktree_row.abs_path, + root_name: worktree_row.root_name, + visible: worktree_row.visible, + entries: Default::default(), + diagnostic_summaries: Default::default(), + scan_id: worktree_row.scan_id as u64, + is_complete: worktree_row.is_complete, + }, + ) + }) + .collect::>(); + + let mut params = "(?, ?),".repeat(worktrees.len()); + if !worktrees.is_empty() { + params.pop(); + } + + // Populate worktree entries. 
+ { + let query = format!( + " + SELECT * + FROM worktree_entries + WHERE (project_id, worktree_id) IN ({params}) + ", + ); + let mut entries = sqlx::query_as::<_, WorktreeEntry>(&query); + for worktree_id in worktrees.keys() { + entries = entries.bind(project_id).bind(*worktree_id); + } + let mut entries = entries.fetch(&mut tx); + while let Some(entry) = entries.next().await { + let entry = entry?; + if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) { + worktree.entries.push(proto::Entry { + id: entry.id as u64, + is_dir: entry.is_dir, + path: entry.path.into_bytes(), + inode: entry.inode as u64, + mtime: Some(proto::Timestamp { + seconds: entry.mtime_seconds as u64, + nanos: entry.mtime_nanos as u32, + }), + is_symlink: entry.is_symlink, + is_ignored: entry.is_ignored, + }); + } + } + } + + // Populate worktree diagnostic summaries. + { + let query = format!( + " + SELECT * + FROM worktree_diagnostic_summaries + WHERE (project_id, worktree_id) IN ({params}) + ", + ); + let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(&query); + for worktree_id in worktrees.keys() { + summaries = summaries.bind(project_id).bind(*worktree_id); + } + let mut summaries = summaries.fetch(&mut tx); + while let Some(summary) = summaries.next().await { + let summary = summary?; + if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) { + worktree + .diagnostic_summaries + .push(proto::DiagnosticSummary { + path: summary.path, + language_server_id: summary.language_server_id as u64, + error_count: summary.error_count as u32, + warning_count: summary.warning_count as u32, + }); + } + } + } + + // Populate language servers. 
+ let language_servers = sqlx::query_as::<_, LanguageServer>( + " + SELECT * + FROM language_servers + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await?; tx.commit().await?; - todo!() + Ok(( + Project { + collaborators, + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id.to_proto(), + name: language_server.name, + }) + .collect(), + }, + replica_id as ReplicaId, + )) }) .await - // sqlx::query( - // " - // SELECT replica_id - // FROM project_collaborators - // WHERE project_id = $ - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .bind(user_id) - // .bind(0) - // .bind(true) - // .execute(&mut tx) - // .await?; - // sqlx::query( - // " - // INSERT INTO project_collaborators ( - // project_id, - // connection_id, - // user_id, - // replica_id, - // is_host - // ) - // VALUES ($1, $2, $3, $4, $5) - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .bind(user_id) - // .bind(0) - // .bind(true) - // .execute(&mut tx) - // .await?; } pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { @@ -2089,32 +2192,72 @@ pub struct Room { id_type!(ProjectId); pub struct Project { - pub id: ProjectId, pub collaborators: Vec, - pub worktrees: BTreeMap, + pub worktrees: BTreeMap, pub language_servers: Vec, } +id_type!(ReplicaId); #[derive(Clone, Debug, Default, FromRow, PartialEq)] pub struct ProjectCollaborator { pub project_id: ProjectId, pub connection_id: i32, pub user_id: UserId, - pub replica_id: i32, + pub replica_id: ReplicaId, pub is_host: bool, } -#[derive(Default)] +id_type!(WorktreeId); +#[derive(Clone, Debug, Default, FromRow, PartialEq)] +struct WorktreeRow { + pub id: WorktreeId, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub scan_id: i64, + pub is_complete: bool, +} + pub struct Worktree { - pub abs_path: PathBuf, + pub id: WorktreeId, + pub 
abs_path: String, pub root_name: String, pub visible: bool, - pub entries: BTreeMap, - pub diagnostic_summaries: BTreeMap, + pub entries: Vec, + pub diagnostic_summaries: Vec, pub scan_id: u64, pub is_complete: bool, } +#[derive(Clone, Debug, Default, FromRow, PartialEq)] +struct WorktreeEntry { + id: i64, + worktree_id: WorktreeId, + is_dir: bool, + path: String, + inode: i64, + mtime_seconds: i64, + mtime_nanos: i32, + is_symlink: bool, + is_ignored: bool, +} + +#[derive(Clone, Debug, Default, FromRow, PartialEq)] +struct WorktreeDiagnosticSummary { + worktree_id: WorktreeId, + path: String, + language_server_id: i64, + error_count: i32, + warning_count: i32, +} + +id_type!(LanguageServerId); +#[derive(Clone, Debug, Default, FromRow, PartialEq)] +struct LanguageServer { + id: LanguageServerId, + name: String, +} + pub struct LeftProject { pub id: ProjectId, pub host_user_id: UserId, diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index b54f03ce53e0aa6200814f8db9f1fc67744b718a..1236af42cb05af4b544f74166284d34aa3e44739 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{NewUserParams, ProjectId, SqliteTestDb as TestDb, UserId}, + db::{NewUserParams, SqliteTestDb as TestDb, UserId}, rpc::{Executor, Server}, AppState, }; @@ -2401,12 +2401,6 @@ async fn test_collaborating_with_diagnostics( // Wait for server to see the diagnostics update. deterministic.run_until_parked(); - { - let store = server.store.lock().await; - let project = store.project(ProjectId::from_proto(project_id)).unwrap(); - let worktree = project.worktrees.get(&worktree_id.to_proto()).unwrap(); - assert!(!worktree.diagnostic_summaries.is_empty()); - } // Ensure client B observes the new diagnostics. 
project_b.read_with(cx_b, |project, cx| { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 02d8f25f38af2464ba63076da6cb11ed6ee28225..3c7d4ec61b6e07be2bdbd61274d52548afb4cb77 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -42,7 +42,6 @@ use std::{ marker::PhantomData, net::SocketAddr, ops::{Deref, DerefMut}, - os::unix::prelude::OsStrExt, rc::Rc, sync::{ atomic::{AtomicBool, Ordering::SeqCst}, @@ -930,16 +929,8 @@ impl Server { ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let guest_user_id = request.sender_user_id; - let host_user_id; - let host_connection_id; - { - let state = self.store().await; - let project = state.project(project_id)?; - host_user_id = project.host.user_id; - host_connection_id = project.host_connection_id; - }; - tracing::info!(%project_id, %host_user_id, %host_connection_id, "join project"); + tracing::info!(%project_id, "join project"); let (project, replica_id) = self .app_state @@ -952,7 +943,7 @@ impl Server { .iter() .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, - replica_id: collaborator.replica_id as u32, + replica_id: collaborator.replica_id.0 as u32, user_id: collaborator.user_id.to_proto(), }) .collect::>(); @@ -960,10 +951,10 @@ impl Server { .worktrees .iter() .map(|(id, worktree)| proto::WorktreeMetadata { - id: *id, + id: id.to_proto(), root_name: worktree.root_name.clone(), visible: worktree.visible, - abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(), + abs_path: worktree.abs_path.as_bytes().to_vec(), }) .collect::>(); @@ -977,7 +968,7 @@ impl Server { project_id: project_id.to_proto(), collaborator: Some(proto::Collaborator { peer_id: request.sender_connection_id.0, - replica_id: replica_id as u32, + replica_id: replica_id.0 as u32, user_id: guest_user_id.to_proto(), }), }, @@ -989,12 +980,12 @@ impl Server { // First, we send the metadata associated with each worktree. 
response.send(proto::JoinProjectResponse { worktrees: worktrees.clone(), - replica_id: replica_id as u32, + replica_id: replica_id.0 as u32, collaborators: collaborators.clone(), language_servers: project.language_servers.clone(), })?; - for (worktree_id, worktree) in &project.worktrees { + for (worktree_id, worktree) in project.worktrees { #[cfg(any(test, feature = "test-support"))] const MAX_CHUNK_SIZE: usize = 2; #[cfg(not(any(test, feature = "test-support")))] @@ -1003,10 +994,10 @@ impl Server { // Stream this worktree's entries. let message = proto::UpdateWorktree { project_id: project_id.to_proto(), - worktree_id: *worktree_id, - abs_path: worktree.abs_path.as_os_str().as_bytes().to_vec(), - root_name: worktree.root_name.clone(), - updated_entries: worktree.entries.values().cloned().collect(), + worktree_id: worktree_id.to_proto(), + abs_path: worktree.abs_path.as_bytes().to_vec(), + root_name: worktree.root_name, + updated_entries: worktree.entries, removed_entries: Default::default(), scan_id: worktree.scan_id, is_last_update: worktree.is_complete, @@ -1017,13 +1008,13 @@ impl Server { } // Stream this worktree's diagnostics. - for summary in worktree.diagnostic_summaries.values() { + for summary in worktree.diagnostic_summaries { self.peer.send( request.sender_connection_id, proto::UpdateDiagnosticSummary { project_id: project_id.to_proto(), - worktree_id: *worktree_id, - summary: Some(summary.clone()), + worktree_id: worktree.id.to_proto(), + summary: Some(summary), }, )?; } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 4be93547889683d75a7439fb98673ad4532e308a..a93182d50bf6363622d5fe90248f594600adb1d3 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -294,49 +294,6 @@ impl Store { Err(anyhow!("no such project"))? 
} - pub fn join_project( - &mut self, - requester_connection_id: ConnectionId, - project_id: ProjectId, - ) -> Result<(&Project, ReplicaId)> { - let connection = self - .connections - .get_mut(&requester_connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let user = self - .connected_users - .get(&connection.user_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let active_call = user.active_call.ok_or_else(|| anyhow!("no such project"))?; - anyhow::ensure!( - active_call.connection_id == Some(requester_connection_id), - "no such project" - ); - - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - anyhow::ensure!(project.room_id == active_call.room_id, "no such project"); - - connection.projects.insert(project_id); - let mut replica_id = 1; - while project.active_replica_ids.contains(&replica_id) { - replica_id += 1; - } - project.active_replica_ids.insert(replica_id); - project.guests.insert( - requester_connection_id, - Collaborator { - replica_id, - user_id: connection.user_id, - admin: connection.admin, - }, - ); - - Ok((project, replica_id)) - } - pub fn leave_project( &mut self, project_id: ProjectId, @@ -409,12 +366,6 @@ impl Store { .connection_ids()) } - pub fn project(&self, project_id: ProjectId) -> Result<&Project> { - self.projects - .get(&project_id) - .ok_or_else(|| anyhow!("no such project")) - } - pub fn read_project( &self, project_id: ProjectId, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index e688cad1f8e01a6c1926712438a16c85927b5d60..8aed5ef5cf5cfc5f5b3e2375e3bd6595edf85801 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -282,13 +282,6 @@ message UpdateWorktree { bytes abs_path = 8; } -message UpdateWorktreeExtensions { - uint64 project_id = 1; - uint64 worktree_id = 2; - repeated string extensions = 3; - repeated uint32 counts = 4; -} - message CreateProjectEntry { uint64 project_id = 1; uint64 worktree_id = 2; From 
4b1dcf2d55002ded81dfccc4ed93193c51be184c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 16:46:17 +0100 Subject: [PATCH 024/109] Always use strings to represent paths over the wire Previously, the protocol used a mix of strings and bytes without any consistency. When we go to multiple platforms, we won't be able to mix encodings of paths anyway. We don't know this is the right approach, but it at least makes things consistent and easy to read in the database, on the wire, etc. Really, we should be using entry ids etc to refer to entries on the wire anyway, but there's a chance this is the wrong decision. Co-Authored-By: Nathan Sobo --- crates/call/src/room.rs | 4 ++-- crates/collab/src/db.rs | 6 +++--- crates/collab/src/rpc.rs | 4 ++-- crates/project/src/project.rs | 14 ++++++-------- crates/project/src/project_tests.rs | 6 +++++- crates/project/src/worktree.rs | 21 ++++++++------------- crates/rpc/proto/zed.proto | 12 ++++++------ 7 files changed, 32 insertions(+), 35 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 4ba8d8effc4831599bb0e358a37fe535b3220f16..8c1b0d9de09f42ecf48e10d67c31b1a6b5508350 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -10,7 +10,7 @@ use gpui::{AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext use live_kit_client::{LocalTrackPublication, LocalVideoTrack, RemoteVideoTrackUpdate}; use postage::stream::Stream; use project::Project; -use std::{mem, os::unix::prelude::OsStrExt, sync::Arc}; +use std::{mem, sync::Arc}; use util::{post_inc, ResultExt}; #[derive(Clone, Debug, PartialEq, Eq)] @@ -553,7 +553,7 @@ impl Room { id: worktree.id().to_proto(), root_name: worktree.root_name().into(), visible: worktree.is_visible(), - abs_path: worktree.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: worktree.abs_path().to_string_lossy().into(), } }) .collect(), diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 
6db4ad101b35170554433f2e71f52021fddbf60f..4cd3ce3a7c6317172398e22c38908b4e8334f475 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1446,7 +1446,7 @@ where .bind(project_id) .bind(worktree.id as i32) .bind(&worktree.root_name) - .bind(&*String::from_utf8_lossy(&worktree.abs_path)) + .bind(&worktree.abs_path) .bind(worktree.visible) .bind(0) .bind(false) @@ -1510,7 +1510,7 @@ where .bind(project_id) .bind(worktree.id as i32) .bind(&worktree.root_name) - .bind(String::from_utf8_lossy(&worktree.abs_path).as_ref()) + .bind(&worktree.abs_path) .bind(worktree.visible) .bind(0) .bind(false) @@ -1687,7 +1687,7 @@ where worktree.entries.push(proto::Entry { id: entry.id as u64, is_dir: entry.is_dir, - path: entry.path.into_bytes(), + path: entry.path, inode: entry.inode as u64, mtime: Some(proto::Timestamp { seconds: entry.mtime_seconds as u64, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 3c7d4ec61b6e07be2bdbd61274d52548afb4cb77..5fcb8d5f9c1e1fcfd797d486652a98ba8c43a3f2 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -954,7 +954,7 @@ impl Server { id: id.to_proto(), root_name: worktree.root_name.clone(), visible: worktree.visible, - abs_path: worktree.abs_path.as_bytes().to_vec(), + abs_path: worktree.abs_path.clone(), }) .collect::>(); @@ -995,7 +995,7 @@ impl Server { let message = proto::UpdateWorktree { project_id: project_id.to_proto(), worktree_id: worktree_id.to_proto(), - abs_path: worktree.abs_path.as_bytes().to_vec(), + abs_path: worktree.abs_path.clone(), root_name: worktree.root_name, updated_entries: worktree.entries, removed_entries: Default::default(), diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index c59b19de8fe2774a3d9b1c6b80a529e40d850c3b..9ac10d14062edfb556e44c66d79423ab98d12aac 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -44,12 +44,10 @@ use std::{ cell::RefCell, cmp::{self, Ordering}, convert::TryInto, - 
ffi::OsString, hash::Hash, mem, num::NonZeroU32, ops::Range, - os::unix::{ffi::OsStrExt, prelude::OsStringExt}, path::{Component, Path, PathBuf}, rc::Rc, str, @@ -837,7 +835,7 @@ impl Project { .request(proto::CreateProjectEntry { worktree_id: project_path.worktree_id.to_proto(), project_id, - path: project_path.path.as_os_str().as_bytes().to_vec(), + path: project_path.path.to_string_lossy().into(), is_directory, }) .await?; @@ -881,7 +879,7 @@ impl Project { .request(proto::CopyProjectEntry { project_id, entry_id: entry_id.to_proto(), - new_path: new_path.as_os_str().as_bytes().to_vec(), + new_path: new_path.to_string_lossy().into(), }) .await?; let entry = response @@ -924,7 +922,7 @@ impl Project { .request(proto::RenameProjectEntry { project_id, entry_id: entry_id.to_proto(), - new_path: new_path.as_os_str().as_bytes().to_vec(), + new_path: new_path.to_string_lossy().into(), }) .await?; let entry = response @@ -4606,7 +4604,7 @@ impl Project { let entry = worktree .update(&mut cx, |worktree, cx| { let worktree = worktree.as_local_mut().unwrap(); - let path = PathBuf::from(OsString::from_vec(envelope.payload.path)); + let path = PathBuf::from(envelope.payload.path); worktree.create_entry(path, envelope.payload.is_directory, cx) }) .await?; @@ -4630,7 +4628,7 @@ impl Project { let worktree_scan_id = worktree.read_with(&cx, |worktree, _| worktree.scan_id()); let entry = worktree .update(&mut cx, |worktree, cx| { - let new_path = PathBuf::from(OsString::from_vec(envelope.payload.new_path)); + let new_path = PathBuf::from(envelope.payload.new_path); worktree .as_local_mut() .unwrap() @@ -4658,7 +4656,7 @@ impl Project { let worktree_scan_id = worktree.read_with(&cx, |worktree, _| worktree.scan_id()); let entry = worktree .update(&mut cx, |worktree, cx| { - let new_path = PathBuf::from(OsString::from_vec(envelope.payload.new_path)); + let new_path = PathBuf::from(envelope.payload.new_path); worktree .as_local_mut() .unwrap() diff --git 
a/crates/project/src/project_tests.rs b/crates/project/src/project_tests.rs index ca274b18b8a37f74d9587470c2a9877d900505e8..77d2a610d5378756ee199097f242ba5a8ec535ad 100644 --- a/crates/project/src/project_tests.rs +++ b/crates/project/src/project_tests.rs @@ -2166,7 +2166,11 @@ async fn test_rescan_and_remote_updates( proto::WorktreeMetadata { id: initial_snapshot.id().to_proto(), root_name: initial_snapshot.root_name().into(), - abs_path: initial_snapshot.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: initial_snapshot + .abs_path() + .as_os_str() + .to_string_lossy() + .into(), visible: true, }, rpc.clone(), diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 9e4ec3ffb9a236e8b9b13c871269833e225fa1b3..ddd4a7a6c847998fec8564e147b9f4ff30fa2177 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -40,7 +40,6 @@ use std::{ future::Future, mem, ops::{Deref, DerefMut}, - os::unix::prelude::{OsStrExt, OsStringExt}, path::{Path, PathBuf}, sync::{atomic::AtomicUsize, Arc}, task::Poll, @@ -221,7 +220,7 @@ impl Worktree { let root_name = worktree.root_name.clone(); let visible = worktree.visible; - let abs_path = PathBuf::from(OsString::from_vec(worktree.abs_path)); + let abs_path = PathBuf::from(worktree.abs_path); let snapshot = Snapshot { id: WorktreeId(remote_id as usize), abs_path: Arc::from(abs_path.deref()), @@ -656,7 +655,7 @@ impl LocalWorktree { id: self.id().to_proto(), root_name: self.root_name().to_string(), visible: self.visible, - abs_path: self.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: self.abs_path().as_os_str().to_string_lossy().into(), } } @@ -990,7 +989,7 @@ impl LocalWorktree { let update = proto::UpdateWorktree { project_id, worktree_id, - abs_path: snapshot.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: snapshot.abs_path().to_string_lossy().into(), root_name: snapshot.root_name().to_string(), updated_entries: snapshot .entries_by_path @@ -1381,7 +1380,7 @@ impl 
LocalSnapshot { proto::UpdateWorktree { project_id, worktree_id: self.id().to_proto(), - abs_path: self.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: self.abs_path().to_string_lossy().into(), root_name, updated_entries: self.entries_by_path.iter().map(Into::into).collect(), removed_entries: Default::default(), @@ -1449,7 +1448,7 @@ impl LocalSnapshot { proto::UpdateWorktree { project_id, worktree_id, - abs_path: self.abs_path().as_os_str().as_bytes().to_vec(), + abs_path: self.abs_path().to_string_lossy().into(), root_name: self.root_name().to_string(), updated_entries, removed_entries, @@ -2928,7 +2927,7 @@ impl<'a> From<&'a Entry> for proto::Entry { Self { id: entry.id.to_proto(), is_dir: entry.is_dir(), - path: entry.path.as_os_str().as_bytes().to_vec(), + path: entry.path.to_string_lossy().into(), inode: entry.inode, mtime: Some(entry.mtime.into()), is_symlink: entry.is_symlink, @@ -2946,14 +2945,10 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry { EntryKind::Dir } else { let mut char_bag = *root_char_bag; - char_bag.extend( - String::from_utf8_lossy(&entry.path) - .chars() - .map(|c| c.to_ascii_lowercase()), - ); + char_bag.extend(entry.path.chars().map(|c| c.to_ascii_lowercase())); EntryKind::File(char_bag) }; - let path: Arc = PathBuf::from(OsString::from_vec(entry.path)).into(); + let path: Arc = PathBuf::from(entry.path).into(); Ok(Entry { id: ProjectEntryId::from_proto(entry.id), kind, diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 8aed5ef5cf5cfc5f5b3e2375e3bd6595edf85801..30c1c89e8f8b393f96e13c96ad9ea42e14ff7a7e 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -279,26 +279,26 @@ message UpdateWorktree { repeated uint64 removed_entries = 5; uint64 scan_id = 6; bool is_last_update = 7; - bytes abs_path = 8; + string abs_path = 8; } message CreateProjectEntry { uint64 project_id = 1; uint64 worktree_id = 2; - bytes path = 3; + string path = 3; bool is_directory = 4; } message 
RenameProjectEntry { uint64 project_id = 1; uint64 entry_id = 2; - bytes new_path = 3; + string new_path = 3; } message CopyProjectEntry { uint64 project_id = 1; uint64 entry_id = 2; - bytes new_path = 3; + string new_path = 3; } message DeleteProjectEntry { @@ -884,7 +884,7 @@ message File { message Entry { uint64 id = 1; bool is_dir = 2; - bytes path = 3; + string path = 3; uint64 inode = 4; Timestamp mtime = 5; bool is_symlink = 6; @@ -1068,7 +1068,7 @@ message WorktreeMetadata { uint64 id = 1; string root_name = 2; bool visible = 3; - bytes abs_path = 4; + string abs_path = 4; } message UpdateDiffBase { From e9eadcaa6a61247f59d5bba629e5db64bfeef49f Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 17:18:28 +0100 Subject: [PATCH 025/109] Move `Store::update_worktree` to `Db::update_worktree` --- .../20221109000000_test_schema.sql | 12 +- crates/collab/src/db.rs | 126 ++++++++++++++++++ crates/collab/src/rpc.rs | 17 +-- crates/collab/src/rpc/store.rs | 51 +------ 4 files changed, 139 insertions(+), 67 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index cffb549a891cb97e83bf16a428d4b8a9a57669d1..24edd69d31b09ce7f8547d616d49ccd8e452adf1 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -61,8 +61,8 @@ CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); CREATE TABLE "worktree_entries" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), - "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, "inode" INTEGER NOT NULL, @@ -71,17 +71,19 @@ CREATE TABLE "worktree_entries" ( "is_symlink" BOOL NOT NULL, "is_ignored" BOOL NOT NULL, PRIMARY KEY(project_id, worktree_id, id) + 
FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); CREATE TABLE "worktree_diagnostic_summaries" ( "path" VARCHAR NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), - "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, - PRIMARY KEY(project_id, worktree_id, path) + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ); CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 4cd3ce3a7c6317172398e22c38908b4e8334f475..d61cdd334d7db9c2a9b26a287c4ad653d98306d3 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1556,6 +1556,132 @@ where .await } + pub async fn update_worktree( + &self, + update: &proto::UpdateWorktree, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async move { + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = WorktreeId::from_proto(update.worktree_id); + + // Ensure the update comes from the host. + sqlx::query( + " + SELECT 1 + FROM projects + WHERE id = $1 AND host_connection_id = $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + + // Update metadata. 
+ sqlx::query( + " + UPDATE worktrees + SET + root_name = $1, + scan_id = $2, + is_complete = $3, + abs_path = $4 + WHERE project_id = $5 AND id = $6 + RETURNING 1 + ", + ) + .bind(&update.root_name) + .bind(update.scan_id as i64) + .bind(update.is_last_update) + .bind(&update.abs_path) + .bind(project_id) + .bind(worktree_id) + .fetch_one(&mut tx) + .await?; + + if !update.updated_entries.is_empty() { + let mut params = + "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len()); + params.pop(); + + let query = format!( + " + INSERT INTO worktree_entries ( + project_id, + worktree_id, + id, + is_dir, + path, + inode, + mtime_seconds, + mtime_nanos, + is_symlink, + is_ignored + ) + VALUES {params} + " + ); + let mut query = sqlx::query(&query); + for entry in &update.updated_entries { + let mtime = entry.mtime.clone().unwrap_or_default(); + query = query + .bind(project_id) + .bind(worktree_id) + .bind(entry.id as i64) + .bind(entry.is_dir) + .bind(&entry.path) + .bind(entry.inode as i64) + .bind(mtime.seconds as i64) + .bind(mtime.nanos as i32) + .bind(entry.is_symlink) + .bind(entry.is_ignored); + } + query.execute(&mut tx).await?; + } + + if !update.removed_entries.is_empty() { + let mut params = "(?, ?, ?),".repeat(update.removed_entries.len()); + params.pop(); + let query = format!( + " + DELETE FROM worktree_entries + WHERE (project_id, worktree_id, id) IN ({params}) + " + ); + + let mut query = sqlx::query(&query); + for entry_id in &update.removed_entries { + query = query + .bind(project_id) + .bind(worktree_id) + .bind(*entry_id as i64); + } + query.execute(&mut tx).await?; + } + + let connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 AND connection_id != $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_all(&mut tx) + .await?; + + tx.commit().await?; + + Ok(connection_ids + .into_iter() + .map(|connection_id| ConnectionId(connection_id 
as u32)) + .collect()) + }) + .await + } + pub async fn join_project( &self, project_id: ProjectId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 5fcb8d5f9c1e1fcfd797d486652a98ba8c43a3f2..1943f18ceb43ec1c5b74a109e1aab03578295aa8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1105,18 +1105,11 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let worktree_id = request.payload.worktree_id; - let connection_ids = self.store().await.update_worktree( - request.sender_connection_id, - project_id, - worktree_id, - &request.payload.root_name, - &request.payload.removed_entries, - &request.payload.updated_entries, - request.payload.scan_id, - request.payload.is_last_update, - )?; + let connection_ids = self + .app_state + .db + .update_worktree(&request.payload, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index a93182d50bf6363622d5fe90248f594600adb1d3..e3abc8dd3c04c9392a962a842205ef0a01ba1180 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -3,7 +3,7 @@ use anyhow::{anyhow, Result}; use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::{path::PathBuf, str}; +use std::path::PathBuf; use tracing::instrument; pub type RoomId = u64; @@ -325,37 +325,6 @@ impl Store { }) } - #[allow(clippy::too_many_arguments)] - pub fn update_worktree( - &mut self, - connection_id: ConnectionId, - project_id: ProjectId, - worktree_id: u64, - worktree_root_name: &str, - removed_entries: &[u64], - updated_entries: &[proto::Entry], - scan_id: u64, - is_last_update: bool, - ) -> Result> { - let project = self.write_project(project_id, connection_id)?; - - let connection_ids = project.connection_ids(); - let mut worktree = 
project.worktrees.entry(worktree_id).or_default(); - worktree.root_name = worktree_root_name.to_string(); - - for entry_id in removed_entries { - worktree.entries.remove(entry_id); - } - - for entry in updated_entries { - worktree.entries.insert(entry.id, entry.clone()); - } - - worktree.scan_id = scan_id; - worktree.is_complete = is_last_update; - Ok(connection_ids) - } - pub fn project_connection_ids( &self, project_id: ProjectId, @@ -384,24 +353,6 @@ impl Store { } } - fn write_project( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<&mut Project> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id - || project.guests.contains_key(&connection_id) - { - Ok(project) - } else { - Err(anyhow!("no such project"))? - } - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { From ad67f5e4de5c086c1c42642f6d1656d8e599c344 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 17:49:37 +0100 Subject: [PATCH 026/109] Always use the database to retrieve collaborators for a project --- crates/collab/src/db.rs | 58 +++++++++++ crates/collab/src/rpc.rs | 174 +++++++++++++++++++-------------- crates/collab/src/rpc/store.rs | 28 ------ 3 files changed, 160 insertions(+), 100 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index d61cdd334d7db9c2a9b26a287c4ad653d98306d3..e503188e1dc4621329119368ac5e6dd376d0a72d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1886,6 +1886,64 @@ where .await } + pub async fn project_collaborators( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async move { + let collaborators = sqlx::query_as::<_, ProjectCollaborator>( + " + SELECT * + FROM project_collaborators + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + 
.await?; + + if collaborators + .iter() + .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) + { + Ok(collaborators) + } else { + Err(anyhow!("no such project"))? + } + }) + .await + } + + pub async fn project_connection_ids( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async move { + let connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await?; + + if connection_ids.contains(&(connection_id.0 as i32)) { + Ok(connection_ids + .into_iter() + .map(|connection_id| ConnectionId(connection_id as u32)) + .collect()) + } else { + Err(anyhow!("no such project"))? + } + }) + .await + } + pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { todo!() // test_support!(self, { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 1943f18ceb43ec1c5b74a109e1aab03578295aa8..f0116f04f9b3a27c3c0cd5378f1187364aa4f249 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1187,13 +1187,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1214,25 +1216,25 @@ impl Server { T: EntityMessage + RequestMessage, { let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); - let host_connection_id = self - .store() - .await - .read_project(project_id, request.sender_connection_id)? 
- .host_connection_id; + let collaborators = self + .app_state + .db + .project_collaborators(project_id, request.sender_connection_id) + .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + let payload = self .peer .forward_request( request.sender_connection_id, - host_connection_id, + ConnectionId(host.connection_id as u32), request.payload, ) .await?; - // Ensure project still exists by the time we get the response from the host. - self.store() - .await - .read_project(project_id, request.sender_connection_id)?; - response.send(payload)?; Ok(()) } @@ -1243,25 +1245,39 @@ impl Server { response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let host = self - .store() - .await - .read_project(project_id, request.sender_connection_id)? - .host_connection_id; + let collaborators = self + .app_state + .db + .project_collaborators(project_id, request.sender_connection_id) + .await?; + let host = collaborators + .into_iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + let host_connection_id = ConnectionId(host.connection_id as u32); let response_payload = self .peer - .forward_request(request.sender_connection_id, host, request.payload.clone()) + .forward_request( + request.sender_connection_id, + host_connection_id, + request.payload.clone(), + ) .await?; - let mut guests = self - .store() - .await - .read_project(project_id, request.sender_connection_id)? 
- .connection_ids(); - guests.retain(|guest_connection_id| *guest_connection_id != request.sender_connection_id); - broadcast(host, guests, |conn_id| { + let mut collaborators = self + .app_state + .db + .project_collaborators(project_id, request.sender_connection_id) + .await?; + collaborators.retain(|collaborator| { + collaborator.connection_id != request.sender_connection_id.0 as i32 + }); + let project_connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); + broadcast(host_connection_id, project_connection_ids, |conn_id| { self.peer - .forward_send(host, conn_id, response_payload.clone()) + .forward_send(host_connection_id, conn_id, response_payload.clone()) }); response.send(response_payload)?; Ok(()) @@ -1285,14 +1301,15 @@ impl Server { response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let receiver_ids = { - let store = self.store().await; - store.project_connection_ids(project_id, request.sender_connection_id)? 
- }; + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1309,13 +1326,16 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1331,13 +1351,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1350,13 +1372,15 @@ impl Server { } async fn buffer_saved(self: Arc, request: Message) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( 
request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1376,14 +1400,14 @@ impl Server { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); let follower_id = request.sender_connection_id; - { - let store = self.store().await; - if !store - .project_connection_ids(project_id, follower_id)? - .contains(&leader_id) - { - Err(anyhow!("no such peer"))?; - } + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; } let mut response_payload = self @@ -1400,11 +1424,12 @@ impl Server { async fn unfollow(self: Arc, request: Message) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let leader_id = ConnectionId(request.payload.leader_id); - let store = self.store().await; - if !store - .project_connection_ids(project_id, request.sender_connection_id)? 
- .contains(&leader_id) - { + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + if !project_connection_ids.contains(&leader_id) { Err(anyhow!("no such peer"))?; } self.peer @@ -1417,9 +1442,12 @@ impl Server { request: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); - let store = self.store().await; - let connection_ids = - store.project_connection_ids(project_id, request.sender_connection_id)?; + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; + let leader_id = request .payload .variant @@ -1431,7 +1459,7 @@ impl Server { }); for follower_id in &request.payload.follower_ids { let follower_id = ConnectionId(*follower_id); - if connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { + if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { self.peer.forward_send( request.sender_connection_id, follower_id, @@ -1629,13 +1657,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.project_connection_ids( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - )?; + let project_id = ProjectId::from_proto(request.payload.project_id); + let project_connection_ids = self + .app_state + .db + .project_connection_ids(project_id, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + project_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index e3abc8dd3c04c9392a962a842205ef0a01ba1180..f694440a50b2a62345cb69382ef123f4cd39e320 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -325,34 +325,6 @@ impl Store { }) } - pub fn project_connection_ids( 
- &self, - project_id: ProjectId, - acting_connection_id: ConnectionId, - ) -> Result> { - Ok(self - .read_project(project_id, acting_connection_id)? - .connection_ids()) - } - - pub fn read_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<&Project> { - let project = self - .projects - .get(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id - || project.guests.contains_key(&connection_id) - { - Ok(project) - } else { - Err(anyhow!("no such project"))? - } - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { From 0817f905a2baf20b034844beb38459a63916ccc2 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 15 Nov 2022 18:02:07 +0100 Subject: [PATCH 027/109] Fix syntax error in schema --- .../20221109000000_test_schema.sql | 2 +- .../20221111092550_reconnection_support.sql | 40 ++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 24edd69d31b09ce7f8547d616d49ccd8e452adf1..ccb09af454a2ed10c29146ca42b5ceb0a30f6084 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -70,7 +70,7 @@ CREATE TABLE "worktree_entries" ( "mtime_nanos" INTEGER NOT NULL, "is_symlink" BOOL NOT NULL, "is_ignored" BOOL NOT NULL, - PRIMARY KEY(project_id, worktree_id, id) + PRIMARY KEY(project_id, worktree_id, id), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 
a5b49ad7636ef5e4aa398a31d199bf7e49bc5dd4..e0e594d46e588f7e0125374f096f0f37f8bbfa9a 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -9,17 +9,6 @@ ALTER TABLE "projects" ADD "host_connection_id" INTEGER, DROP COLUMN "unregistered"; -CREATE TABLE "project_collaborators" ( - "id" SERIAL PRIMARY KEY, - "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, - "connection_id" INTEGER NOT NULL, - "user_id" INTEGER NOT NULL, - "replica_id" INTEGER NOT NULL, - "is_host" BOOLEAN NOT NULL -); -CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); -CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); - CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id), @@ -34,8 +23,8 @@ CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); CREATE TABLE "worktree_entries" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), - "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + "project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, "inode" INTEGER NOT NULL, @@ -43,18 +32,20 @@ CREATE TABLE "worktree_entries" ( "mtime_nanos" INTEGER NOT NULL, "is_symlink" BOOL NOT NULL, "is_ignored" BOOL NOT NULL, - PRIMARY KEY(project_id, worktree_id, id) + PRIMARY KEY(project_id, worktree_id, id), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); CREATE TABLE "worktree_diagnostic_summaries" ( "path" VARCHAR NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), - "worktree_id" INTEGER NOT NULL REFERENCES worktrees (id), + 
"project_id" INTEGER NOT NULL, + "worktree_id" INTEGER NOT NULL, "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, - PRIMARY KEY(project_id, worktree_id, path) + PRIMARY KEY(project_id, worktree_id, path), + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ); CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); @@ -66,8 +57,19 @@ CREATE TABLE "language_servers" ( ); CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); -CREATE TABLE IF NOT EXISTS "room_participants" ( - "id" SERIAL PRIMARY KEY, +CREATE TABLE "project_collaborators" ( + "id" INTEGER PRIMARY KEY, + "project_id" INTEGER NOT NULL REFERENCES projects (id), + "connection_id" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL, + "replica_id" INTEGER NOT NULL, + "is_host" BOOLEAN NOT NULL +); +CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); +CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); + +CREATE TABLE "room_participants" ( + "id" INTEGER PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), "user_id" INTEGER NOT NULL REFERENCES users (id), "answering_connection_id" INTEGER, From 31902363968c95ebb364d42ebac2871a3761bbc7 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 08:57:19 +0100 Subject: [PATCH 028/109] Update worktree entry instead of erroring when it already exists --- crates/collab/src/db.rs | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e503188e1dc4621329119368ac5e6dd376d0a72d..44cc382ee0027ab22491debf8923a05cf79075f7 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1620,6 +1620,14 @@ where is_ignored ) VALUES {params} + ON CONFLICT 
(project_id, worktree_id, id) DO UPDATE SET + is_dir = excluded.is_dir, + path = excluded.path, + inode = excluded.inode, + mtime_seconds = excluded.mtime_seconds, + mtime_nanos = excluded.mtime_nanos, + is_symlink = excluded.is_symlink, + is_ignored = excluded.is_ignored " ); let mut query = sqlx::query(&query); From c151c87e12e58d3dd121857ccaeaa267a99bec52 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 10:36:48 +0100 Subject: [PATCH 029/109] Correctly leave projects when leaving room --- crates/collab/src/db.rs | 108 ++++++++++++++++++++++++++------------- crates/collab/src/rpc.rs | 72 +++++++++++++------------- 2 files changed, 107 insertions(+), 73 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 44cc382ee0027ab22491debf8923a05cf79075f7..78b6547ef2d92460e9cb91f26358e8607984e6f9 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1171,44 +1171,68 @@ where .fetch_all(&mut tx) .await?; - let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>( + let project_ids = sqlx::query_scalar::<_, ProjectId>( " - SELECT project_collaborators.* - FROM projects, project_collaborators - WHERE - projects.room_id = $1 AND - projects.id = project_collaborators.project_id AND - project_collaborators.connection_id = $2 + SELECT project_id + FROM project_collaborators + WHERE connection_id = $1 ", ) - .bind(room_id) .bind(connection_id.0 as i32) - .fetch(&mut tx); + .fetch_all(&mut tx) + .await?; + // Leave projects. 
let mut left_projects = HashMap::default(); - while let Some(collaborator) = project_collaborators.next().await { - let collaborator = collaborator?; - let left_project = - left_projects - .entry(collaborator.project_id) - .or_insert(LeftProject { - id: collaborator.project_id, - host_user_id: Default::default(), - connection_ids: Default::default(), - }); - - let collaborator_connection_id = - ConnectionId(collaborator.connection_id as u32); - if collaborator_connection_id != connection_id || collaborator.is_host { - left_project.connection_ids.push(collaborator_connection_id); + if !project_ids.is_empty() { + let mut params = "?,".repeat(project_ids.len()); + params.pop(); + let query = format!( + " + SELECT * + FROM project_collaborators + WHERE project_id IN ({params}) + " + ); + let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query); + for project_id in project_ids { + query = query.bind(project_id); } - if collaborator.is_host { - left_project.host_user_id = collaborator.user_id; + let mut project_collaborators = query.fetch(&mut tx); + while let Some(collaborator) = project_collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + }); + + let collaborator_connection_id = + ConnectionId(collaborator.connection_id as u32); + if collaborator_connection_id != connection_id { + left_project.connection_ids.push(collaborator_connection_id); + } + + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + } } } - drop(project_collaborators); + sqlx::query( + " + DELETE FROM project_collaborators + WHERE connection_id = $1 + ", + ) + .bind(connection_id.0 as i32) + .execute(&mut tx) + .await?; + // Unshare projects. 
sqlx::query( " DELETE FROM projects @@ -1265,15 +1289,16 @@ where sqlx::query( " UPDATE room_participants - SET location_kind = $1 AND location_project_id = $2 + SET location_kind = $1, location_project_id = $2 WHERE room_id = $3 AND answering_connection_id = $4 + RETURNING 1 ", ) .bind(location_kind) .bind(location_project_id) .bind(room_id) .bind(connection_id.0 as i32) - .execute(&mut tx) + .fetch_one(&mut tx) .await?; self.commit_room_transaction(room_id, tx).await @@ -1335,21 +1360,32 @@ where let ( user_id, answering_connection_id, - _location_kind, - _location_project_id, + location_kind, + location_project_id, calling_user_id, initial_project_id, ) = participant?; if let Some(answering_connection_id) = answering_connection_id { + let location = match (location_kind, location_project_id) { + (Some(0), Some(project_id)) => { + Some(proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { + id: project_id.to_proto(), + }, + )) + } + (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( + Default::default(), + )), + _ => Some(proto::participant_location::Variant::External( + Default::default(), + )), + }; participants.push(proto::Participant { user_id: user_id.to_proto(), peer_id: answering_connection_id as u32, projects: Default::default(), - location: Some(proto::ParticipantLocation { - variant: Some(proto::participant_location::Variant::External( - Default::default(), - )), - }), + location: Some(proto::ParticipantLocation { variant: location }), }); } else { pending_participants.push(proto::PendingParticipant { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index f0116f04f9b3a27c3c0cd5378f1187364aa4f249..9f7d21a1a93b4a780ec1ce008c68ba5bc2560867 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -624,19 +624,19 @@ impl Server { async fn leave_room_for_connection( self: &Arc, - connection_id: ConnectionId, - user_id: UserId, + leaving_connection_id: 
ConnectionId, + leaving_user_id: UserId, ) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let Some(left_room) = self.app_state.db.leave_room_for_connection(connection_id).await? else { + let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else { return Err(anyhow!("no room to leave"))?; }; - contacts_to_update.insert(user_id); + contacts_to_update.insert(leaving_user_id); for project in left_room.left_projects.into_values() { - if project.host_user_id == user_id { - for connection_id in project.connection_ids { + for connection_id in project.connection_ids { + if project.host_user_id == leaving_user_id { self.peer .send( connection_id, @@ -645,29 +645,27 @@ impl Server { }, ) .trace_err(); - } - } else { - for connection_id in project.connection_ids { + } else { self.peer .send( connection_id, proto::RemoveProjectCollaborator { project_id: project.id.to_proto(), - peer_id: connection_id.0, + peer_id: leaving_connection_id.0, }, ) .trace_err(); } - - self.peer - .send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); } + + self.peer + .send( + leaving_connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); } self.room_updated(&left_room.room); @@ -691,7 +689,7 @@ impl Server { live_kit .remove_participant( left_room.room.live_kit_room.clone(), - connection_id.to_string(), + leaving_connection_id.to_string(), ) .await .trace_err(); @@ -941,6 +939,9 @@ impl Server { let collaborators = project .collaborators .iter() + .filter(|collaborator| { + collaborator.connection_id != request.sender_connection_id.0 as i32 + }) .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, replica_id: collaborator.replica_id.0 as u32, @@ -958,23 +959,20 @@ impl Server { }) .collect::>(); - for collaborator in &project.collaborators { - let connection_id = 
ConnectionId(collaborator.connection_id as u32); - if connection_id != request.sender_connection_id { - self.peer - .send( - connection_id, - proto::AddProjectCollaborator { - project_id: project_id.to_proto(), - collaborator: Some(proto::Collaborator { - peer_id: request.sender_connection_id.0, - replica_id: replica_id.0 as u32, - user_id: guest_user_id.to_proto(), - }), - }, - ) - .trace_err(); - } + for collaborator in &collaborators { + self.peer + .send( + ConnectionId(collaborator.peer_id), + proto::AddProjectCollaborator { + project_id: project_id.to_proto(), + collaborator: Some(proto::Collaborator { + peer_id: request.sender_connection_id.0, + replica_id: replica_id.0 as u32, + user_id: guest_user_id.to_proto(), + }), + }, + ) + .trace_err(); } // First, we send the metadata associated with each worktree. From f9567ae1166559df1bf4d66397159c24c46a3d15 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 10:41:36 +0100 Subject: [PATCH 030/109] Cascade deletes when project is deleted --- .../migrations.sqlite/20221109000000_test_schema.sql | 10 +++++----- .../migrations/20221111092550_reconnection_support.sql | 10 +++++----- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index ccb09af454a2ed10c29146ca42b5ceb0a30f6084..9914831bbab02e12b8947c6fe01459720fe52717 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -49,7 +49,7 @@ CREATE TABLE "projects" ( CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "root_name" VARCHAR NOT NULL, "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, @@ -71,7 +71,7 @@ CREATE TABLE "worktree_entries" ( "is_symlink" BOOL NOT NULL, "is_ignored" 
BOOL NOT NULL, PRIMARY KEY(project_id, worktree_id, id), - FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); @@ -83,13 +83,13 @@ CREATE TABLE "worktree_diagnostic_summaries" ( "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), - FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); CREATE TABLE "language_servers" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "name" VARCHAR NOT NULL, PRIMARY KEY(project_id, id) ); @@ -97,7 +97,7 @@ CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("proj CREATE TABLE "project_collaborators" ( "id" INTEGER PRIMARY KEY, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL, diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index e0e594d46e588f7e0125374f096f0f37f8bbfa9a..8cd53726fdc704fe168eed77c81f021583f74fef 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -11,7 +11,7 @@ ALTER TABLE "projects" CREATE TABLE "worktrees" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL 
REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "root_name" VARCHAR NOT NULL, "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, @@ -33,7 +33,7 @@ CREATE TABLE "worktree_entries" ( "is_symlink" BOOL NOT NULL, "is_ignored" BOOL NOT NULL, PRIMARY KEY(project_id, worktree_id, id), - FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); @@ -45,13 +45,13 @@ CREATE TABLE "worktree_diagnostic_summaries" ( "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), - FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) + FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); CREATE TABLE "language_servers" ( "id" INTEGER NOT NULL, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "name" VARCHAR NOT NULL, PRIMARY KEY(project_id, id) ); @@ -59,7 +59,7 @@ CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("proj CREATE TABLE "project_collaborators" ( "id" INTEGER PRIMARY KEY, - "project_id" INTEGER NOT NULL REFERENCES projects (id), + "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL, From eeb32fa88809f04a1b40730ce1fedb171de7b551 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 11:07:39 +0100 Subject: [PATCH 031/109] Improve queries for composite primary keys --- crates/collab/src/db.rs | 
34 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 78b6547ef2d92460e9cb91f26358e8607984e6f9..785965905ad19bacaea426e281ae383654b9ae4c 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1554,20 +1554,20 @@ where .await?; } - let mut params = "(?, ?),".repeat(worktrees.len()); + let mut params = "?,".repeat(worktrees.len()); if !worktrees.is_empty() { params.pop(); } let query = format!( " DELETE FROM worktrees - WHERE (project_id, id) NOT IN ({params}) + WHERE project_id = ? AND worktree_id NOT IN ({params}) ", ); - let mut query = sqlx::query(&query); + let mut query = sqlx::query(&query).bind(project_id); for worktree in worktrees { - query = query.bind(project_id).bind(WorktreeId(worktree.id as i32)); + query = query.bind(WorktreeId(worktree.id as i32)); } query.execute(&mut tx).await?; @@ -1685,21 +1685,18 @@ where } if !update.removed_entries.is_empty() { - let mut params = "(?, ?, ?),".repeat(update.removed_entries.len()); + let mut params = "?,".repeat(update.removed_entries.len()); params.pop(); let query = format!( " DELETE FROM worktree_entries - WHERE (project_id, worktree_id, entry_id) IN ({params}) + WHERE project_id = ? AND worktree_id = ? AND entry_id IN ({params}) " ); - let mut query = sqlx::query(&query); + let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id); for entry_id in &update.removed_entries { - query = query - .bind(project_id) - .bind(worktree_id) - .bind(*entry_id as i64); + query = query.bind(*entry_id as i64); } query.execute(&mut tx).await?; } @@ -1832,7 +1829,7 @@ where }) .collect::>(); - let mut params = "(?, ?),".repeat(worktrees.len()); + let mut params = "?,".repeat(worktrees.len()); if !worktrees.is_empty() { params.pop(); } @@ -1843,12 +1840,12 @@ where " SELECT * FROM worktree_entries - WHERE (project_id, worktree_id) IN ({params}) + WHERE project_id = ? 
AND worktree_id IN ({params}) ", ); - let mut entries = sqlx::query_as::<_, WorktreeEntry>(&query); + let mut entries = sqlx::query_as::<_, WorktreeEntry>(&query).bind(project_id); for worktree_id in worktrees.keys() { - entries = entries.bind(project_id).bind(*worktree_id); + entries = entries.bind(*worktree_id); } let mut entries = entries.fetch(&mut tx); while let Some(entry) = entries.next().await { @@ -1876,12 +1873,13 @@ where " SELECT * FROM worktree_diagnostic_summaries - WHERE (project_id, worktree_id) IN ({params}) + WHERE project_id = $1 AND worktree_id IN ({params}) ", ); - let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>(&query); + let mut summaries = + sqlx::query_as::<_, WorktreeDiagnosticSummary>(&query).bind(project_id); for worktree_id in worktrees.keys() { - summaries = summaries.bind(project_id).bind(*worktree_id); + summaries = summaries.bind(*worktree_id); } let mut summaries = summaries.fetch(&mut tx); while let Some(summary) = summaries.next().await { From 117458f4f6c567ae691d4a1b716a6f5e8daef717 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 14:58:11 +0100 Subject: [PATCH 032/109] Send worktree updates after project metadata has been sent --- crates/collab/src/db.rs | 3 +- crates/project/src/project.rs | 75 +++++++++++++++++++++---------- crates/workspace/src/workspace.rs | 5 ++- 3 files changed, 56 insertions(+), 27 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 785965905ad19bacaea426e281ae383654b9ae4c..f058d3bfe159c11266ab5183e7dfd32b53fbdce8 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1561,7 +1561,7 @@ where let query = format!( " DELETE FROM worktrees - WHERE project_id = ? AND worktree_id NOT IN ({params}) + WHERE project_id = ? 
AND id NOT IN ({params}) ", ); @@ -1580,6 +1580,7 @@ where WHERE project_id = $1 AND is_host = FALSE ", ) + .bind(project_id) .fetch(&mut tx); while let Some(connection_id) = db_guest_connection_ids.next().await { guest_connection_ids.push(ConnectionId(connection_id? as u32)); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 9ac10d14062edfb556e44c66d79423ab98d12aac..436b2d92a26e713de120ea084c86062aff5a45e6 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -10,7 +10,11 @@ use anyhow::{anyhow, Context, Result}; use client::{proto, Client, PeerId, TypedEnvelope, UserStore}; use clock::ReplicaId; use collections::{hash_map, BTreeMap, HashMap, HashSet}; -use futures::{future::Shared, AsyncWriteExt, Future, FutureExt, StreamExt, TryFutureExt}; +use futures::{ + channel::{mpsc, oneshot}, + future::Shared, + AsyncWriteExt, Future, FutureExt, StreamExt, TryFutureExt, +}; use gpui::{ AnyModelHandle, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, @@ -145,7 +149,7 @@ enum WorktreeHandle { enum ProjectClientState { Local { remote_id: u64, - metadata_changed: watch::Sender<()>, + metadata_changed: mpsc::UnboundedSender>, _maintain_metadata: Task<()>, _detect_unshare: Task>, }, @@ -533,7 +537,7 @@ impl Project { nonce: StdRng::from_entropy().gen(), }; for worktree in worktrees { - this.add_worktree(&worktree, cx); + let _ = this.add_worktree(&worktree, cx); } this }); @@ -728,14 +732,22 @@ impl Project { } } - fn metadata_changed(&mut self, cx: &mut ModelContext) { + fn metadata_changed(&mut self, cx: &mut ModelContext) -> impl Future { + let (tx, rx) = oneshot::channel(); if let Some(ProjectClientState::Local { metadata_changed, .. 
}) = &mut self.client_state { - *metadata_changed.borrow_mut() = (); + let _ = metadata_changed.unbounded_send(tx); } cx.notify(); + + async move { + // If the project is shared, this will resolve when the `_maintain_metadata` task has + // a chance to update the metadata. Otherwise, it will resolve right away because `tx` + // will get dropped. + let _ = rx.await; + } } pub fn collaborators(&self) -> &HashMap { @@ -1025,17 +1037,22 @@ impl Project { self.client_subscriptions .push(self.client.add_model_for_remote_entity(project_id, cx)); - self.metadata_changed(cx); + let _ = self.metadata_changed(cx); cx.emit(Event::RemoteIdChanged(Some(project_id))); cx.notify(); let mut status = self.client.status(); - let (metadata_changed_tx, mut metadata_changed_rx) = watch::channel(); + let (metadata_changed_tx, mut metadata_changed_rx) = mpsc::unbounded(); self.client_state = Some(ProjectClientState::Local { remote_id: project_id, metadata_changed: metadata_changed_tx, _maintain_metadata: cx.spawn_weak(move |this, cx| async move { - while let Some(()) = metadata_changed_rx.next().await { + while let Some(tx) = metadata_changed_rx.next().await { + let mut txs = vec![tx]; + while let Ok(Some(next_tx)) = metadata_changed_rx.try_next() { + txs.push(next_tx); + } + let Some(this) = this.upgrade(&cx) else { break }; this.read_with(&cx, |this, cx| { let worktrees = this @@ -1054,6 +1071,10 @@ impl Project { }) .await .log_err(); + + for tx in txs { + let _ = tx.send(()); + } } }), _detect_unshare: cx.spawn_weak(move |this, mut cx| { @@ -1105,7 +1126,7 @@ impl Project { } } - self.metadata_changed(cx); + let _ = self.metadata_changed(cx); cx.notify(); self.client.send(proto::UnshareProject { project_id: remote_id, @@ -4162,12 +4183,13 @@ impl Project { }); let worktree = worktree?; - let project_id = project.update(&mut cx, |project, cx| { - project.add_worktree(&worktree, cx); - project.remote_id() - }); + project + .update(&mut cx, |project, cx| project.add_worktree(&worktree, 
cx)) + .await; - if let Some(project_id) = project_id { + if let Some(project_id) = + project.read_with(&cx, |project, _| project.remote_id()) + { worktree .update(&mut cx, |worktree, cx| { worktree.as_local_mut().unwrap().share(project_id, cx) @@ -4191,7 +4213,11 @@ impl Project { }) } - pub fn remove_worktree(&mut self, id_to_remove: WorktreeId, cx: &mut ModelContext) { + pub fn remove_worktree( + &mut self, + id_to_remove: WorktreeId, + cx: &mut ModelContext, + ) -> impl Future { self.worktrees.retain(|worktree| { if let Some(worktree) = worktree.upgrade(cx) { let id = worktree.read(cx).id(); @@ -4205,11 +4231,14 @@ impl Project { false } }); - self.metadata_changed(cx); - cx.notify(); + self.metadata_changed(cx) } - fn add_worktree(&mut self, worktree: &ModelHandle, cx: &mut ModelContext) { + fn add_worktree( + &mut self, + worktree: &ModelHandle, + cx: &mut ModelContext, + ) -> impl Future { cx.observe(worktree, |_, _, cx| cx.notify()).detach(); if worktree.read(cx).is_local() { cx.subscribe(worktree, |this, worktree, event, cx| match event { @@ -4233,15 +4262,13 @@ impl Project { .push(WorktreeHandle::Weak(worktree.downgrade())); } - self.metadata_changed(cx); cx.observe_release(worktree, |this, worktree, cx| { - this.remove_worktree(worktree.id(), cx); - cx.notify(); + let _ = this.remove_worktree(worktree.id(), cx); }) .detach(); cx.emit(Event::WorktreeAdded); - cx.notify(); + self.metadata_changed(cx) } fn update_local_worktree_buffers( @@ -4558,11 +4585,11 @@ impl Project { } else { let worktree = Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx); - this.add_worktree(&worktree, cx); + let _ = this.add_worktree(&worktree, cx); } } - this.metadata_changed(cx); + let _ = this.metadata_changed(cx); for (id, _) in old_worktrees_by_id { cx.emit(Event::WorktreeRemoved(id)); } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 
9db524ee9ba2b935d817ce64081d8ee374bb363a..2296741ed3c7f31768c2bd5857a463e18179c4fe 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1531,7 +1531,8 @@ impl Workspace { RemoveWorktreeFromProject(worktree_id): &RemoveWorktreeFromProject, cx: &mut ViewContext, ) { - self.project + let _ = self + .project .update(cx, |project, cx| project.remove_worktree(*worktree_id, cx)); } @@ -3177,7 +3178,7 @@ mod tests { // Remove a project folder project.update(cx, |project, cx| { - project.remove_worktree(worktree_id, cx); + let _ = project.remove_worktree(worktree_id, cx); }); assert_eq!( cx.current_window_title(window_id).as_deref(), From 95369f92ebb91f645dfee1eccf0f981081ef50ab Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 15:41:33 +0100 Subject: [PATCH 033/109] Move `Store::update_diagnostic_summary` to `Db` --- crates/collab/src/db.rs | 115 +++++++++++++++++++++++++-------- crates/collab/src/rpc.rs | 22 +++---- crates/collab/src/rpc/store.rs | 25 ------- 3 files changed, 97 insertions(+), 65 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index f058d3bfe159c11266ab5183e7dfd32b53fbdce8..3d913bb47d0ac51cf8267f607736ed40232d7bbc 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1724,6 +1724,81 @@ where .await } + pub async fn update_diagnostic_summary( + &self, + update: &proto::UpdateDiagnosticSummary, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async { + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = WorktreeId::from_proto(update.worktree_id); + let summary = update + .summary + .as_ref() + .ok_or_else(|| anyhow!("invalid summary"))?; + + // Ensure the update comes from the host. + sqlx::query( + " + SELECT 1 + FROM projects + WHERE id = $1 AND host_connection_id = $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + + // Update summary. 
+ sqlx::query( + " + INSERT INTO worktree_diagnostic_summaries ( + project_id, + worktree_id, + path, + language_server_id, + error_count, + warning_count + ) + VALUES ($1, $2, $3, $4, $5, $6) + ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET + language_server_id = excluded.language_server_id, + error_count = excluded.error_count, + warning_count = excluded.warning_count + ", + ) + .bind(project_id) + .bind(worktree_id) + .bind(&summary.path) + .bind(summary.language_server_id as i64) + .bind(summary.error_count as i32) + .bind(summary.warning_count as i32) + .execute(&mut tx) + .await?; + + let connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 AND connection_id != $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_all(&mut tx) + .await?; + + tx.commit().await?; + + Ok(connection_ids + .into_iter() + .map(|connection_id| ConnectionId(connection_id as u32)) + .collect()) + }) + .await + } + pub async fn join_project( &self, project_id: ProjectId, @@ -1830,25 +1905,17 @@ where }) .collect::>(); - let mut params = "?,".repeat(worktrees.len()); - if !worktrees.is_empty() { - params.pop(); - } - // Populate worktree entries. { - let query = format!( + let mut entries = sqlx::query_as::<_, WorktreeEntry>( " - SELECT * - FROM worktree_entries - WHERE project_id = ? AND worktree_id IN ({params}) + SELECT * + FROM worktree_entries + WHERE project_id = $1 ", - ); - let mut entries = sqlx::query_as::<_, WorktreeEntry>(&query).bind(project_id); - for worktree_id in worktrees.keys() { - entries = entries.bind(*worktree_id); - } - let mut entries = entries.fetch(&mut tx); + ) + .bind(project_id) + .fetch(&mut tx); while let Some(entry) = entries.next().await { let entry = entry?; if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) { @@ -1870,19 +1937,15 @@ where // Populate worktree diagnostic summaries. 
{ - let query = format!( + let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>( " - SELECT * - FROM worktree_diagnostic_summaries - WHERE project_id = $1 AND worktree_id IN ({params}) + SELECT * + FROM worktree_diagnostic_summaries + WHERE project_id = $1 ", - ); - let mut summaries = - sqlx::query_as::<_, WorktreeDiagnosticSummary>(&query).bind(project_id); - for worktree_id in worktrees.keys() { - summaries = summaries.bind(*worktree_id); - } - let mut summaries = summaries.fetch(&mut tx); + ) + .bind(project_id) + .fetch(&mut tx); while let Some(summary) = summaries.next().await { let summary = summary?; if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9f7d21a1a93b4a780ec1ce008c68ba5bc2560867..ac971f8f0359c3bd542ca8a6e28f0bb7f8bd694b 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1103,7 +1103,7 @@ impl Server { request: Message, response: Response, ) -> Result<()> { - let connection_ids = self + let guest_connection_ids = self .app_state .db .update_worktree(&request.payload, request.sender_connection_id) @@ -1111,7 +1111,7 @@ impl Server { broadcast( request.sender_connection_id, - connection_ids, + guest_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, @@ -1128,21 +1128,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let summary = request - .payload - .summary - .clone() - .ok_or_else(|| anyhow!("invalid summary"))?; - let receiver_ids = self.store().await.update_diagnostic_summary( - ProjectId::from_proto(request.payload.project_id), - request.payload.worktree_id, - request.sender_connection_id, - summary, - )?; + let guest_connection_ids = self + .app_state + .db + .update_diagnostic_summary(&request.payload, request.sender_connection_id) + .await?; broadcast( request.sender_connection_id, - receiver_ids, + guest_connection_ids, |connection_id| { self.peer.forward_send( 
request.sender_connection_id, diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index f694440a50b2a62345cb69382ef123f4cd39e320..1be778e83a789536d5f9e1ef2327707b8e2d966a 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -251,31 +251,6 @@ impl Store { } } - pub fn update_diagnostic_summary( - &mut self, - project_id: ProjectId, - worktree_id: u64, - connection_id: ConnectionId, - summary: proto::DiagnosticSummary, - ) -> Result> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id { - let worktree = project - .worktrees - .get_mut(&worktree_id) - .ok_or_else(|| anyhow!("no such worktree"))?; - worktree - .diagnostic_summaries - .insert(summary.path.clone().into(), summary); - return Ok(project.connection_ids()); - } - - Err(anyhow!("no such worktree"))? - } - pub fn start_language_server( &mut self, project_id: ProjectId, From 9bc57c0c61df9e8c3cf6429fb530cac77dac7577 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 15:48:26 +0100 Subject: [PATCH 034/109] Move `Store::start_language_server` to `Db` --- .../20221109000000_test_schema.sql | 6 +- .../20221111092550_reconnection_support.sql | 8 +-- crates/collab/src/db.rs | 62 +++++++++++++++++++ crates/collab/src/rpc.rs | 17 +++-- crates/collab/src/rpc/store.rs | 18 ------ 5 files changed, 76 insertions(+), 35 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 9914831bbab02e12b8947c6fe01459720fe52717..66925fddd55fba36464eef2fab7b4f30af75362f 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -48,8 +48,8 @@ CREATE TABLE "projects" ( ); CREATE TABLE "worktrees" ( - "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects 
(id) ON DELETE CASCADE, + "id" INTEGER NOT NULL, "root_name" VARCHAR NOT NULL, "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, @@ -60,9 +60,9 @@ CREATE TABLE "worktrees" ( CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); CREATE TABLE "worktree_entries" ( - "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL, "worktree_id" INTEGER NOT NULL, + "id" INTEGER NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, "inode" INTEGER NOT NULL, @@ -76,9 +76,9 @@ CREATE TABLE "worktree_entries" ( CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); CREATE TABLE "worktree_diagnostic_summaries" ( - "path" VARCHAR NOT NULL, "project_id" INTEGER NOT NULL, "worktree_id" INTEGER NOT NULL, + "path" VARCHAR NOT NULL, "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 8cd53726fdc704fe168eed77c81f021583f74fef..4f4ad6aede8b2160c19d0899b505a9ca1c48b3aa 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -10,8 +10,8 @@ ALTER TABLE "projects" DROP COLUMN "unregistered"; CREATE TABLE "worktrees" ( - "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INTEGER NOT NULL, "root_name" VARCHAR NOT NULL, "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, @@ -22,9 +22,9 @@ CREATE TABLE "worktrees" ( CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); CREATE TABLE "worktree_entries" ( - "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL, "worktree_id" INTEGER NOT NULL, + "id" INTEGER NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, "inode" INTEGER NOT NULL, @@ -38,9 +38,9 @@ CREATE TABLE "worktree_entries" ( 
CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); CREATE TABLE "worktree_diagnostic_summaries" ( - "path" VARCHAR NOT NULL, "project_id" INTEGER NOT NULL, "worktree_id" INTEGER NOT NULL, + "path" VARCHAR NOT NULL, "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, @@ -50,8 +50,8 @@ CREATE TABLE "worktree_diagnostic_summaries" ( CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); CREATE TABLE "language_servers" ( - "id" INTEGER NOT NULL, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, + "id" INTEGER NOT NULL, "name" VARCHAR NOT NULL, PRIMARY KEY(project_id, id) ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 3d913bb47d0ac51cf8267f607736ed40232d7bbc..9163e71aa4dd1e9d5e209830c8c2ace21b82c3e9 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1799,6 +1799,68 @@ where .await } + pub async fn start_language_server( + &self, + update: &proto::StartLanguageServer, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|mut tx| async { + let project_id = ProjectId::from_proto(update.project_id); + let server = update + .server + .as_ref() + .ok_or_else(|| anyhow!("invalid language server"))?; + + // Ensure the update comes from the host. + sqlx::query( + " + SELECT 1 + FROM projects + WHERE id = $1 AND host_connection_id = $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + + // Add the newly-started language server. 
+ sqlx::query( + " + INSERT INTO language_servers (project_id, id, name) + VALUES ($1, $2, $3) + ON CONFLICT (project_id, id) DO UPDATE SET + name = excluded.name + ", + ) + .bind(project_id) + .bind(server.id as i64) + .bind(&server.name) + .execute(&mut tx) + .await?; + + let connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 AND connection_id != $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_all(&mut tx) + .await?; + + tx.commit().await?; + + Ok(connection_ids + .into_iter() + .map(|connection_id| ConnectionId(connection_id as u32)) + .collect()) + }) + .await + } + pub async fn join_project( &self, project_id: ProjectId, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index ac971f8f0359c3bd542ca8a6e28f0bb7f8bd694b..5e3018160c85b24a51bf04587f880d22008df8e4 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1152,18 +1152,15 @@ impl Server { self: Arc, request: Message, ) -> Result<()> { - let receiver_ids = self.store().await.start_language_server( - ProjectId::from_proto(request.payload.project_id), - request.sender_connection_id, - request - .payload - .server - .clone() - .ok_or_else(|| anyhow!("invalid language server"))?, - )?; + let guest_connection_ids = self + .app_state + .db + .start_language_server(&request.payload, request.sender_connection_id) + .await?; + broadcast( request.sender_connection_id, - receiver_ids, + guest_connection_ids, |connection_id| { self.peer.forward_send( request.sender_connection_id, diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 1be778e83a789536d5f9e1ef2327707b8e2d966a..57dd726d3facb9a8b3186b7833540c6cfe6f31fc 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -251,24 +251,6 @@ impl Store { } } - pub fn start_language_server( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - language_server: 
proto::LanguageServer, - ) -> Result> { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id == connection_id { - project.language_servers.push(language_server); - return Ok(project.connection_ids()); - } - - Err(anyhow!("no such project"))? - } - pub fn leave_project( &mut self, project_id: ProjectId, From faf265328e9adc46423766f9275a7a7a668a99de Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 16:03:01 +0100 Subject: [PATCH 035/109] Wait for acknowledgment before sending the next diagnostic summary --- crates/collab/src/rpc.rs | 5 ++- crates/project/src/worktree.rs | 57 ++++++++++++++++++---------------- crates/rpc/src/proto.rs | 1 + 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 5e3018160c85b24a51bf04587f880d22008df8e4..db8f25fdb28c56a15a7ea5504951e8a796d1b05e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -154,7 +154,7 @@ impl Server { .add_request_handler(Server::update_worktree) .add_message_handler(Server::start_language_server) .add_message_handler(Server::update_language_server) - .add_message_handler(Server::update_diagnostic_summary) + .add_request_handler(Server::update_diagnostic_summary) .add_request_handler(Server::forward_project_request::) .add_request_handler(Server::forward_project_request::) .add_request_handler(Server::forward_project_request::) @@ -1127,6 +1127,7 @@ impl Server { async fn update_diagnostic_summary( self: Arc, request: Message, + response: Response, ) -> Result<()> { let guest_connection_ids = self .app_state @@ -1145,6 +1146,8 @@ impl Server { ) }, ); + + response.send(proto::Ack {})?; Ok(()) } diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index ddd4a7a6c847998fec8564e147b9f4ff30fa2177..836ac55b661157f8c2f0297567b55143b8b26d2a 100644 --- a/crates/project/src/worktree.rs +++ 
b/crates/project/src/worktree.rs @@ -166,7 +166,9 @@ enum ScanState { struct ShareState { project_id: u64, snapshots_tx: watch::Sender, - _maintain_remote_snapshot: Option>>, + diagnostic_summaries_tx: mpsc::UnboundedSender<(Arc, DiagnosticSummary)>, + _maintain_remote_snapshot: Task>, + _maintain_remote_diagnostic_summaries: Task<()>, } pub enum Event { @@ -524,18 +526,9 @@ impl LocalWorktree { let updated = !old_summary.is_empty() || !new_summary.is_empty(); if updated { if let Some(share) = self.share.as_ref() { - self.client - .send(proto::UpdateDiagnosticSummary { - project_id: share.project_id, - worktree_id: self.id().to_proto(), - summary: Some(proto::DiagnosticSummary { - path: worktree_path.to_string_lossy().to_string(), - language_server_id: language_server_id as u64, - error_count: new_summary.error_count as u32, - warning_count: new_summary.warning_count as u32, - }), - }) - .log_err(); + let _ = share + .diagnostic_summaries_tx + .unbounded_send((worktree_path.clone(), new_summary)); } } @@ -967,22 +960,10 @@ impl LocalWorktree { let _ = share_tx.send(Ok(())); } else { let (snapshots_tx, mut snapshots_rx) = watch::channel_with(self.snapshot()); - let rpc = self.client.clone(); let worktree_id = cx.model_id() as u64; - for (path, summary) in self.diagnostic_summaries.iter() { - if let Err(e) = rpc.send(proto::UpdateDiagnosticSummary { - project_id, - worktree_id, - summary: Some(summary.to_proto(&path.0)), - }) { - return Task::ready(Err(e)); - } - } - let maintain_remote_snapshot = cx.background().spawn({ - let rpc = rpc; - + let rpc = self.client.clone(); async move { let mut prev_snapshot = match snapshots_rx.recv().await { Some(snapshot) => { @@ -1029,10 +1010,32 @@ impl LocalWorktree { } .log_err() }); + + let (diagnostic_summaries_tx, mut diagnostic_summaries_rx) = mpsc::unbounded(); + for (path, summary) in self.diagnostic_summaries.iter() { + let _ = diagnostic_summaries_tx.unbounded_send((path.0.clone(), summary.clone())); + } + let 
maintain_remote_diagnostic_summaries = cx.background().spawn({ + let rpc = self.client.clone(); + async move { + while let Some((path, summary)) = diagnostic_summaries_rx.next().await { + rpc.request(proto::UpdateDiagnosticSummary { + project_id, + worktree_id, + summary: Some(summary.to_proto(&path)), + }) + .await + .log_err(); + } + } + }); + self.share = Some(ShareState { project_id, snapshots_tx, - _maintain_remote_snapshot: Some(maintain_remote_snapshot), + diagnostic_summaries_tx, + _maintain_remote_snapshot: maintain_remote_snapshot, + _maintain_remote_diagnostic_summaries: maintain_remote_diagnostic_summaries, }); } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 6d9bc9a0aa348af8c1a14f442323fcf06064688e..50f3c57f2a6b3c5bd9bc6798e468df7a541a2f07 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -228,6 +228,7 @@ request_messages!( (ShareProject, ShareProjectResponse), (Test, Test), (UpdateBuffer, Ack), + (UpdateDiagnosticSummary, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), (UpdateWorktree, Ack), From adf43c87dd2f4f8a76e97ff842d2f8eac82aef4c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 17:19:06 +0100 Subject: [PATCH 036/109] Batch some of the new queries in `Db` Co-Authored-By: Nathan Sobo --- crates/collab/src/db.rs | 162 ++++++++++++++++++++++++---------------- 1 file changed, 97 insertions(+), 65 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 9163e71aa4dd1e9d5e209830c8c2ace21b82c3e9..d517bdd1df58199a16e43fa0300db5f0024215df 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1354,7 +1354,7 @@ where .bind(room_id) .fetch(&mut *tx); - let mut participants = Vec::new(); + let mut participants = HashMap::default(); let mut pending_participants = Vec::new(); while let Some(participant) = db_participants.next().await { let ( @@ -1381,12 +1381,15 @@ where Default::default(), )), }; - participants.push(proto::Participant { - 
user_id: user_id.to_proto(), - peer_id: answering_connection_id as u32, - projects: Default::default(), - location: Some(proto::ParticipantLocation { variant: location }), - }); + participants.insert( + answering_connection_id, + proto::Participant { + user_id: user_id.to_proto(), + peer_id: answering_connection_id as u32, + projects: Default::default(), + location: Some(proto::ParticipantLocation { variant: location }), + }, + ); } else { pending_participants.push(proto::PendingParticipant { user_id: user_id.to_proto(), @@ -1397,41 +1400,42 @@ where } drop(db_participants); - for participant in &mut participants { - let mut entries = sqlx::query_as::<_, (ProjectId, String)>( - " - SELECT projects.id, worktrees.root_name - FROM projects - LEFT JOIN worktrees ON projects.id = worktrees.project_id - WHERE room_id = $1 AND host_connection_id = $2 - ", - ) - .bind(room_id) - .bind(participant.peer_id as i32) - .fetch(&mut *tx); + let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option)>( + " + SELECT host_connection_id, projects.id, worktrees.root_name + FROM projects + LEFT JOIN worktrees ON projects.id = worktrees.project_id + WHERE room_id = $1 + ", + ) + .bind(room_id) + .fetch(&mut *tx); - let mut projects = HashMap::default(); - while let Some(entry) = entries.next().await { - let (project_id, worktree_root_name) = entry?; - let participant_project = - projects - .entry(project_id) - .or_insert(proto::ParticipantProject { - id: project_id.to_proto(), - worktree_root_names: Default::default(), - }); - participant_project - .worktree_root_names - .push(worktree_root_name); + while let Some(row) = rows.next().await { + let (connection_id, project_id, worktree_root_name) = row?; + if let Some(participant) = participants.get_mut(&connection_id) { + let project = if let Some(project) = participant + .projects + .iter_mut() + .find(|project| project.id == project_id.to_proto()) + { + project + } else { + participant.projects.push(proto::ParticipantProject { + id: 
project_id.to_proto(), + worktree_root_names: Default::default(), + }); + participant.projects.last_mut().unwrap() + }; + project.worktree_root_names.extend(worktree_root_name); } - - participant.projects = projects.into_values().collect(); } + Ok(proto::Room { id: room.id.to_proto(), version: room.version as u64, live_kit_room: room.live_kit_room, - participants, + participants: participants.into_values().collect(), pending_participants, }) } @@ -1472,22 +1476,36 @@ where .fetch_one(&mut tx) .await?; - for worktree in worktrees { - sqlx::query( + if !worktrees.is_empty() { + let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); + params.pop(); + let query = format!( " - INSERT INTO worktrees (project_id, id, root_name, abs_path, visible, scan_id, is_complete) - VALUES ($1, $2, $3, $4, $5, $6, $7) - ", - ) - .bind(project_id) - .bind(worktree.id as i32) - .bind(&worktree.root_name) - .bind(&worktree.abs_path) - .bind(worktree.visible) - .bind(0) - .bind(false) - .execute(&mut tx) - .await?; + INSERT INTO worktrees ( + project_id, + id, + root_name, + abs_path, + visible, + scan_id, + is_complete + ) + VALUES {params} + " + ); + + let mut query = sqlx::query(&query); + for worktree in worktrees { + query = query + .bind(project_id) + .bind(worktree.id as i32) + .bind(&worktree.root_name) + .bind(&worktree.abs_path) + .bind(worktree.visible) + .bind(0) + .bind(false); + } + query.execute(&mut tx).await?; } sqlx::query( @@ -1535,23 +1553,37 @@ where .fetch_one(&mut tx) .await?; - for worktree in worktrees { - sqlx::query( + if !worktrees.is_empty() { + let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); + params.pop(); + let query = format!( " - INSERT INTO worktrees (project_id, id, root_name, abs_path, visible, scan_id, is_complete) - VALUES ($1, $2, $3, $4, $5, $6, $7) + INSERT INTO worktrees ( + project_id, + id, + root_name, + abs_path, + visible, + scan_id, + is_complete + ) + VALUES ${params} ON CONFLICT (project_id, id) DO UPDATE 
SET root_name = excluded.root_name - ", - ) - .bind(project_id) - .bind(worktree.id as i32) - .bind(&worktree.root_name) - .bind(&worktree.abs_path) - .bind(worktree.visible) - .bind(0) - .bind(false) - .execute(&mut tx) - .await?; + " + ); + + let mut query = sqlx::query(&query); + for worktree in worktrees { + query = query + .bind(project_id) + .bind(worktree.id as i32) + .bind(&worktree.root_name) + .bind(&worktree.abs_path) + .bind(worktree.visible) + .bind(0) + .bind(false) + } + query.execute(&mut tx).await?; } let mut params = "?,".repeat(worktrees.len()); From c1291a093b65f7db4042759557be10c539b02479 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 16 Nov 2022 19:50:57 +0100 Subject: [PATCH 037/109] WIP: Allow subscribing to remote entity before creating a model Co-Authored-By: Nathan Sobo Co-Authored-By: Max Brunsfeld --- crates/client/src/client.rs | 287 +++++++++++++++++++++------------- crates/project/src/project.rs | 30 ++-- 2 files changed, 193 insertions(+), 124 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index c943b274172c8264ee311270d4575973f945e6cc..bad85384be6b78cce7a0b1f33d48dc471fcff22b 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -17,8 +17,7 @@ use gpui::{ actions, serde_json::{self, Value}, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AnyWeakViewHandle, AppContext, - AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, View, ViewContext, - ViewHandle, + AsyncAppContext, Entity, ModelHandle, MutableAppContext, Task, View, ViewContext, ViewHandle, }; use http::HttpClient; use lazy_static::lazy_static; @@ -34,6 +33,7 @@ use std::{ convert::TryFrom, fmt::Write as _, future::Future, + marker::PhantomData, path::PathBuf, sync::{Arc, Weak}, time::{Duration, Instant}, @@ -172,7 +172,7 @@ struct ClientState { entity_id_extractors: HashMap u64>, _reconnect_task: Option>, reconnect_interval: Duration, - entities_by_type_and_remote_id: 
HashMap<(TypeId, u64), AnyWeakEntityHandle>, + entities_by_type_and_remote_id: HashMap<(TypeId, u64), WeakSubscriber>, models_by_message_type: HashMap, entity_types_by_message_type: HashMap, #[allow(clippy::type_complexity)] @@ -182,7 +182,7 @@ struct ClientState { dyn Send + Sync + Fn( - AnyEntityHandle, + Subscriber, Box, &Arc, AsyncAppContext, @@ -191,12 +191,13 @@ struct ClientState { >, } -enum AnyWeakEntityHandle { +enum WeakSubscriber { Model(AnyWeakModelHandle), View(AnyWeakViewHandle), + Pending(Vec>), } -enum AnyEntityHandle { +enum Subscriber { Model(AnyModelHandle), View(AnyViewHandle), } @@ -254,6 +255,54 @@ impl Drop for Subscription { } } +pub struct PendingEntitySubscription { + client: Arc, + remote_id: u64, + _entity_type: PhantomData, + consumed: bool, +} + +impl PendingEntitySubscription { + pub fn set_model(mut self, model: &ModelHandle, cx: &mut AsyncAppContext) -> Subscription { + self.consumed = true; + let mut state = self.client.state.write(); + let id = (TypeId::of::(), self.remote_id); + let Some(WeakSubscriber::Pending(messages)) = + state.entities_by_type_and_remote_id.remove(&id) + else { + unreachable!() + }; + + state + .entities_by_type_and_remote_id + .insert(id, WeakSubscriber::Model(model.downgrade().into())); + drop(state); + for message in messages { + self.client.handle_message(message, cx); + } + Subscription::Entity { + client: Arc::downgrade(&self.client), + id, + } + } +} + +impl Drop for PendingEntitySubscription { + fn drop(&mut self) { + if !self.consumed { + let mut state = self.client.state.write(); + if let Some(WeakSubscriber::Pending(messages)) = state + .entities_by_type_and_remote_id + .remove(&(TypeId::of::(), self.remote_id)) + { + for message in messages { + log::info!("unhandled message {}", message.payload_type_name()); + } + } + } + } +} + impl Client { pub fn new(http: Arc, cx: &AppContext) -> Arc { Arc::new(Self { @@ -387,26 +436,28 @@ impl Client { self.state .write() .entities_by_type_and_remote_id - 
.insert(id, AnyWeakEntityHandle::View(cx.weak_handle().into())); + .insert(id, WeakSubscriber::View(cx.weak_handle().into())); Subscription::Entity { client: Arc::downgrade(self), id, } } - pub fn add_model_for_remote_entity( + pub fn subscribe_to_entity( self: &Arc, remote_id: u64, - cx: &mut ModelContext, - ) -> Subscription { + ) -> PendingEntitySubscription { let id = (TypeId::of::(), remote_id); self.state .write() .entities_by_type_and_remote_id - .insert(id, AnyWeakEntityHandle::Model(cx.weak_handle().into())); - Subscription::Entity { - client: Arc::downgrade(self), - id, + .insert(id, WeakSubscriber::Pending(Default::default())); + + PendingEntitySubscription { + client: self.clone(), + remote_id, + consumed: false, + _entity_type: PhantomData, } } @@ -434,7 +485,7 @@ impl Client { let prev_handler = state.message_handlers.insert( message_type_id, Arc::new(move |handle, envelope, client, cx| { - let handle = if let AnyEntityHandle::Model(handle) = handle { + let handle = if let Subscriber::Model(handle) = handle { handle } else { unreachable!(); @@ -488,7 +539,7 @@ impl Client { F: 'static + Future>, { self.add_entity_message_handler::(move |handle, message, client, cx| { - if let AnyEntityHandle::View(handle) = handle { + if let Subscriber::View(handle) = handle { handler(handle.downcast::().unwrap(), message, client, cx) } else { unreachable!(); @@ -507,7 +558,7 @@ impl Client { F: 'static + Future>, { self.add_entity_message_handler::(move |handle, message, client, cx| { - if let AnyEntityHandle::Model(handle) = handle { + if let Subscriber::Model(handle) = handle { handler(handle.downcast::().unwrap(), message, client, cx) } else { unreachable!(); @@ -522,7 +573,7 @@ impl Client { H: 'static + Send + Sync - + Fn(AnyEntityHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + + Fn(Subscriber, TypedEnvelope, Arc, AsyncAppContext) -> F, F: 'static + Future>, { let model_type_id = TypeId::of::(); @@ -784,94 +835,8 @@ impl Client { let cx = cx.clone(); let 
this = self.clone(); async move { - let mut message_id = 0_usize; while let Some(message) = incoming.next().await { - let mut state = this.state.write(); - message_id += 1; - let type_name = message.payload_type_name(); - let payload_type_id = message.payload_type_id(); - let sender_id = message.original_sender_id().map(|id| id.0); - - let model = state - .models_by_message_type - .get(&payload_type_id) - .and_then(|model| model.upgrade(&cx)) - .map(AnyEntityHandle::Model) - .or_else(|| { - let entity_type_id = - *state.entity_types_by_message_type.get(&payload_type_id)?; - let entity_id = state - .entity_id_extractors - .get(&message.payload_type_id()) - .map(|extract_entity_id| { - (extract_entity_id)(message.as_ref()) - })?; - - let entity = state - .entities_by_type_and_remote_id - .get(&(entity_type_id, entity_id))?; - if let Some(entity) = entity.upgrade(&cx) { - Some(entity) - } else { - state - .entities_by_type_and_remote_id - .remove(&(entity_type_id, entity_id)); - None - } - }); - - let model = if let Some(model) = model { - model - } else { - log::info!("unhandled message {}", type_name); - continue; - }; - - let handler = state.message_handlers.get(&payload_type_id).cloned(); - // Dropping the state prevents deadlocks if the handler interacts with rpc::Client. - // It also ensures we don't hold the lock while yielding back to the executor, as - // that might cause the executor thread driving this future to block indefinitely. - drop(state); - - if let Some(handler) = handler { - let future = handler(model, message, &this, cx.clone()); - let client_id = this.id; - log::debug!( - "rpc message received. client_id:{}, message_id:{}, sender_id:{:?}, type:{}", - client_id, - message_id, - sender_id, - type_name - ); - cx.foreground() - .spawn(async move { - match future.await { - Ok(()) => { - log::debug!( - "rpc message handled. 
client_id:{}, message_id:{}, sender_id:{:?}, type:{}", - client_id, - message_id, - sender_id, - type_name - ); - } - Err(error) => { - log::error!( - "error handling message. client_id:{}, message_id:{}, sender_id:{:?}, type:{}, error:{:?}", - client_id, - message_id, - sender_id, - type_name, - error - ); - } - } - }) - .detach(); - } else { - log::info!("unhandled message {}", type_name); - } - + this.handle_message(message, &cx); // Don't starve the main thread when receiving lots of messages at once. smol::future::yield_now().await; } @@ -1218,6 +1183,97 @@ impl Client { self.peer.respond_with_error(receipt, error) } + fn handle_message( + self: &Arc, + message: Box, + cx: &AsyncAppContext, + ) { + let mut state = self.state.write(); + let type_name = message.payload_type_name(); + let payload_type_id = message.payload_type_id(); + let sender_id = message.original_sender_id().map(|id| id.0); + + let mut subscriber = None; + + if let Some(message_model) = state + .models_by_message_type + .get(&payload_type_id) + .and_then(|model| model.upgrade(cx)) + { + subscriber = Some(Subscriber::Model(message_model)); + } else if let Some((extract_entity_id, entity_type_id)) = + state.entity_id_extractors.get(&payload_type_id).zip( + state + .entity_types_by_message_type + .get(&payload_type_id) + .copied(), + ) + { + let entity_id = (extract_entity_id)(message.as_ref()); + + match state + .entities_by_type_and_remote_id + .get_mut(&(entity_type_id, entity_id)) + { + Some(WeakSubscriber::Pending(pending)) => { + pending.push(message); + return; + } + Some(weak_subscriber @ _) => subscriber = weak_subscriber.upgrade(cx), + _ => {} + } + } + + let subscriber = if let Some(subscriber) = subscriber { + subscriber + } else { + log::info!("unhandled message {}", type_name); + return; + }; + + let handler = state.message_handlers.get(&payload_type_id).cloned(); + // Dropping the state prevents deadlocks if the handler interacts with rpc::Client. 
+ // It also ensures we don't hold the lock while yielding back to the executor, as + // that might cause the executor thread driving this future to block indefinitely. + drop(state); + + if let Some(handler) = handler { + let future = handler(subscriber, message, &self, cx.clone()); + let client_id = self.id; + log::debug!( + "rpc message received. client_id:{}, sender_id:{:?}, type:{}", + client_id, + sender_id, + type_name + ); + cx.foreground() + .spawn(async move { + match future.await { + Ok(()) => { + log::debug!( + "rpc message handled. client_id:{}, sender_id:{:?}, type:{}", + client_id, + sender_id, + type_name + ); + } + Err(error) => { + log::error!( + "error handling message. client_id:{}, sender_id:{:?}, type:{}, error:{:?}", + client_id, + sender_id, + type_name, + error + ); + } + } + }) + .detach(); + } else { + log::info!("unhandled message {}", type_name); + } + } + pub fn start_telemetry(&self, db: Db) { self.telemetry.start(db.clone()); } @@ -1231,11 +1287,12 @@ impl Client { } } -impl AnyWeakEntityHandle { - fn upgrade(&self, cx: &AsyncAppContext) -> Option { +impl WeakSubscriber { + fn upgrade(&self, cx: &AsyncAppContext) -> Option { match self { - AnyWeakEntityHandle::Model(handle) => handle.upgrade(cx).map(AnyEntityHandle::Model), - AnyWeakEntityHandle::View(handle) => handle.upgrade(cx).map(AnyEntityHandle::View), + WeakSubscriber::Model(handle) => handle.upgrade(cx).map(Subscriber::Model), + WeakSubscriber::View(handle) => handle.upgrade(cx).map(Subscriber::View), + WeakSubscriber::Pending(_) => None, } } } @@ -1480,11 +1537,17 @@ mod tests { subscription: None, }); - let _subscription1 = model1.update(cx, |_, cx| client.add_model_for_remote_entity(1, cx)); - let _subscription2 = model2.update(cx, |_, cx| client.add_model_for_remote_entity(2, cx)); + let _subscription1 = client + .subscribe_to_entity(1) + .set_model(&model1, &mut cx.to_async()); + let _subscription2 = client + .subscribe_to_entity(2) + .set_model(&model2, &mut 
cx.to_async()); // Ensure dropping a subscription for the same entity type still allows receiving of // messages for other entity IDs of the same type. - let subscription3 = model3.update(cx, |_, cx| client.add_model_for_remote_entity(3, cx)); + let subscription3 = client + .subscribe_to_entity(3) + .set_model(&model3, &mut cx.to_async()); drop(subscription3); server.send(proto::JoinProject { project_id: 1 }); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 436b2d92a26e713de120ea084c86062aff5a45e6..503ae8d4b24cc290e539121e50e2803939a9ecc7 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -457,22 +457,23 @@ impl Project { ) -> Result, JoinProjectError> { client.authenticate_and_connect(true, &cx).await?; + let subscription = client.subscribe_to_entity(remote_id); let response = client .request(proto::JoinProject { project_id: remote_id, }) .await?; + let this = cx.add_model(|cx| { + let replica_id = response.replica_id as ReplicaId; - let replica_id = response.replica_id as ReplicaId; - - let mut worktrees = Vec::new(); - for worktree in response.worktrees { - let worktree = cx - .update(|cx| Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx)); - worktrees.push(worktree); - } + let mut worktrees = Vec::new(); + for worktree in response.worktrees { + let worktree = cx.update(|cx| { + Worktree::remote(remote_id, replica_id, worktree, client.clone(), cx) + }); + worktrees.push(worktree); + } - let this = cx.add_model(|cx: &mut ModelContext| { let mut this = Self { worktrees: Vec::new(), loading_buffers: Default::default(), @@ -488,7 +489,7 @@ impl Project { fs, next_entry_id: Default::default(), next_diagnostic_group_id: Default::default(), - client_subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)], + client_subscriptions: Default::default(), _subscriptions: Default::default(), client: client.clone(), client_state: Some(ProjectClientState::Remote { @@ -541,6 +542,7 
@@ impl Project { } this }); + let subscription = subscription.set_model(&this, &mut cx); let user_ids = response .collaborators @@ -558,6 +560,7 @@ impl Project { this.update(&mut cx, |this, _| { this.collaborators = collaborators; + this.client_subscriptions.push(subscription); }); Ok(this) @@ -1035,8 +1038,11 @@ impl Project { }); } - self.client_subscriptions - .push(self.client.add_model_for_remote_entity(project_id, cx)); + self.client_subscriptions.push( + self.client + .subscribe_to_entity(project_id) + .set_model(&cx.handle(), &mut cx.to_async()), + ); let _ = self.metadata_changed(cx); cx.emit(Event::RemoteIdChanged(Some(project_id))); cx.notify(); From bdb521cb6beda3618bdaf868e0ca874d26f726cb Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Wed, 16 Nov 2022 14:24:26 -0700 Subject: [PATCH 038/109] Fix typo in query --- crates/collab/src/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index d517bdd1df58199a16e43fa0300db5f0024215df..41cde3bf425778684506ad5eb37d211c80a67761 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1567,7 +1567,7 @@ where scan_id, is_complete ) - VALUES ${params} + VALUES {params} ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name " ); From e5f05c9f3b1f5ffa595769c235d192e1a3e5981c Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Wed, 16 Nov 2022 14:24:26 -0700 Subject: [PATCH 039/109] Move leave_project from Store to db module --- crates/collab/src/db.rs | 70 ++++++++++++++++++++++++++++++++-- crates/collab/src/rpc.rs | 27 ++++++------- crates/collab/src/rpc/store.rs | 31 --------------- crates/rpc/src/peer.rs | 2 +- 4 files changed, 82 insertions(+), 48 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 41cde3bf425778684506ad5eb37d211c80a67761..24b0feb2e9eb8e34bc08477bc57791f88a4d23c9 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1209,6 +1209,7 @@ where id: 
collaborator.project_id, host_user_id: Default::default(), connection_ids: Default::default(), + host_connection_id: Default::default(), }); let collaborator_connection_id = @@ -1219,6 +1220,8 @@ where if collaborator.is_host { left_project.host_user_id = collaborator.user_id; + left_project.host_connection_id = + ConnectionId(collaborator.connection_id as u32); } } } @@ -1474,7 +1477,8 @@ where .bind(user_id) .bind(connection_id.0 as i32) .fetch_one(&mut tx) - .await?; + .await + .unwrap(); if !worktrees.is_empty() { let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); @@ -1505,7 +1509,7 @@ where .bind(0) .bind(false); } - query.execute(&mut tx).await?; + query.execute(&mut tx).await.unwrap(); } sqlx::query( @@ -1526,7 +1530,8 @@ where .bind(0) .bind(true) .execute(&mut tx) - .await?; + .await + .unwrap(); let room = self.commit_room_transaction(room_id, tx).await?; Ok((project_id, room)) @@ -2086,6 +2091,64 @@ where .await } + pub async fn leave_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result { + self.transact(|mut tx| async move { + let result = sqlx::query( + " + DELETE FROM project_collaborators + WHERE project_id = $1 AND connection_id = $2 + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .execute(&mut tx) + .await?; + + if result.rows_affected() != 1 { + Err(anyhow!("not a collaborator on this project"))?; + } + + let connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 + ", + ) + .bind(project_id) + .fetch_all(&mut tx) + .await? 
+ .into_iter() + .map(|id| ConnectionId(id as u32)) + .collect(); + + let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>( + " + SELECT host_user_id, host_connection_id + FROM projects + WHERE id = $1 + ", + ) + .bind(project_id) + .fetch_one(&mut tx) + .await?; + + tx.commit().await?; + + Ok(LeftProject { + id: project_id, + host_user_id: UserId(host_user_id), + host_connection_id: ConnectionId(host_connection_id as u32), + connection_ids, + }) + }) + .await + } + pub async fn project_collaborators( &self, project_id: ProjectId, @@ -2645,6 +2708,7 @@ struct LanguageServer { pub struct LeftProject { pub id: ProjectId, pub host_user_id: UserId, + pub host_connection_id: ConnectionId, pub connection_ids: Vec, } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index db8f25fdb28c56a15a7ea5504951e8a796d1b05e..c32bdb500894c6eb6e7567385ff010772da98ce0 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1041,8 +1041,11 @@ impl Server { let project_id = ProjectId::from_proto(request.payload.project_id); let project; { - let mut store = self.store().await; - project = store.leave_project(project_id, sender_id)?; + project = self + .app_state + .db + .leave_project(project_id, sender_id) + .await?; tracing::info!( %project_id, host_user_id = %project.host_user_id, @@ -1050,17 +1053,15 @@ impl Server { "leave project" ); - if project.remove_collaborator { - broadcast(sender_id, project.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project_id.to_proto(), - peer_id: sender_id.0, - }, - ) - }); - } + broadcast(sender_id, project.connection_ids, |conn_id| { + self.peer.send( + conn_id, + proto::RemoveProjectCollaborator { + project_id: project_id.to_proto(), + peer_id: sender_id.0, + }, + ) + }); } Ok(()) diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 
57dd726d3facb9a8b3186b7833540c6cfe6f31fc..9c93f0daca250199c6a1751d2aad9964785c40d3 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -251,37 +251,6 @@ impl Store { } } - pub fn leave_project( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result { - let project = self - .projects - .get_mut(&project_id) - .ok_or_else(|| anyhow!("no such project"))?; - - // If the connection leaving the project is a collaborator, remove it. - let remove_collaborator = if let Some(guest) = project.guests.remove(&connection_id) { - project.active_replica_ids.remove(&guest.replica_id); - true - } else { - false - }; - - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.projects.remove(&project_id); - } - - Ok(LeftProject { - id: project.id, - host_connection_id: project.host_connection_id, - host_user_id: project.host.user_id, - connection_ids: project.connection_ids(), - remove_collaborator, - }) - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 4dbade4fec7969164ba80eb13dc3592cfb1c1bda..66ba6a40292d87d13a83943b0e40239ee37b526d 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -24,7 +24,7 @@ use std::{ }; use tracing::instrument; -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)] +#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Serialize)] pub struct ConnectionId(pub u32); impl fmt::Display for ConnectionId { From 94fe93c6eee43605f837a9944221085b9a0015f4 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Wed, 16 Nov 2022 14:24:26 -0700 Subject: [PATCH 040/109] Move unshare_project to db module --- crates/collab/src/db.rs | 83 ++++++++++++++++++++-------------- crates/collab/src/rpc.rs | 13 ++++-- crates/collab/src/rpc/store.rs | 66 +-------------------------- 3 files changed, 59 insertions(+), 103 
deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 24b0feb2e9eb8e34bc08477bc57791f88a4d23c9..bc74a8e53046889a7583d5cef0eff593d074a740 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1330,6 +1330,27 @@ where Ok(room) } + async fn get_guest_connection_ids( + &self, + project_id: ProjectId, + tx: &mut sqlx::Transaction<'_, D>, + ) -> Result> { + let mut guest_connection_ids = Vec::new(); + let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( + " + SELECT connection_id + FROM project_collaborators + WHERE project_id = $1 AND is_host = FALSE + ", + ) + .bind(project_id) + .fetch(tx); + while let Some(connection_id) = db_guest_connection_ids.next().await { + guest_connection_ids.push(ConnectionId(connection_id? as u32)); + } + Ok(guest_connection_ids) + } + async fn get_room( &self, room_id: RoomId, @@ -1539,6 +1560,31 @@ where .await } + pub async fn unshare_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result<(proto::Room, Vec)> { + self.transact(|mut tx| async move { + let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; + let room_id: RoomId = sqlx::query_scalar( + " + DELETE FROM projects + WHERE id = $1 AND host_connection_id = $2 + RETURNING room_id + ", + ) + .bind(project_id) + .bind(connection_id.0 as i32) + .fetch_one(&mut tx) + .await?; + let room = self.commit_room_transaction(room_id, tx).await?; + + Ok((room, guest_connection_ids)) + }) + .await + } + pub async fn update_project( &self, project_id: ProjectId, @@ -1608,23 +1654,9 @@ where } query.execute(&mut tx).await?; - let mut guest_connection_ids = Vec::new(); - { - let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 AND is_host = FALSE - ", - ) - .bind(project_id) - .fetch(&mut tx); - while let Some(connection_id) = db_guest_connection_ids.next().await { - 
guest_connection_ids.push(ConnectionId(connection_id? as u32)); - } - } - + let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let room = self.commit_room_transaction(room_id, tx).await?; + Ok((room, guest_connection_ids)) }) .await @@ -2108,7 +2140,7 @@ where .execute(&mut tx) .await?; - if result.rows_affected() != 1 { + if result.rows_affected() == 0 { Err(anyhow!("not a collaborator on this project"))?; } @@ -2207,23 +2239,6 @@ where .await } - pub async fn unshare_project(&self, project_id: ProjectId) -> Result<()> { - todo!() - // test_support!(self, { - // sqlx::query( - // " - // UPDATE projects - // SET unregistered = TRUE - // WHERE id = $1 - // ", - // ) - // .bind(project_id) - // .execute(&self.pool) - // .await?; - // Ok(()) - // }) - } - // contacts pub async fn get_contacts(&self, user_id: UserId) -> Result> { diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index c32bdb500894c6eb6e7567385ff010772da98ce0..45330ca8583eb18468abcd791ca0a2d6804cf60e 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -877,14 +877,19 @@ impl Server { message: Message, ) -> Result<()> { let project_id = ProjectId::from_proto(message.payload.project_id); - let mut store = self.store().await; - let (room, project) = store.unshare_project(project_id, message.sender_connection_id)?; + + let (room, guest_connection_ids) = self + .app_state + .db + .unshare_project(project_id, message.sender_connection_id) + .await?; + broadcast( message.sender_connection_id, - project.guest_connection_ids(), + guest_connection_ids, |conn_id| self.peer.send(conn_id, message.payload.clone()), ); - self.room_updated(room); + self.room_updated(&room); Ok(()) } diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 9c93f0daca250199c6a1751d2aad9964785c40d3..1aa9c709b733dcc317f7274854603b36a8c6bf51 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -1,6 +1,6 @@ use 
crate::db::{self, ProjectId, UserId}; use anyhow::{anyhow, Result}; -use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet}; +use collections::{BTreeMap, BTreeSet, HashMap, HashSet}; use rpc::{proto, ConnectionId}; use serde::Serialize; use std::path::PathBuf; @@ -72,14 +72,6 @@ pub struct Worktree { pub type ReplicaId = u16; -pub struct LeftProject { - pub id: ProjectId, - pub host_user_id: UserId, - pub host_connection_id: ConnectionId, - pub connection_ids: Vec, - pub remove_collaborator: bool, -} - #[derive(Copy, Clone)] pub struct Metrics { pub connections: usize, @@ -209,48 +201,6 @@ impl Store { &self.rooms } - pub fn unshare_project( - &mut self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result<(&proto::Room, Project)> { - match self.projects.entry(project_id) { - btree_map::Entry::Occupied(e) => { - if e.get().host_connection_id == connection_id { - let project = e.remove(); - - if let Some(host_connection) = self.connections.get_mut(&connection_id) { - host_connection.projects.remove(&project_id); - } - - for guest_connection in project.guests.keys() { - if let Some(connection) = self.connections.get_mut(guest_connection) { - connection.projects.remove(&project_id); - } - } - - let room = self - .rooms - .get_mut(&project.room_id) - .ok_or_else(|| anyhow!("no such room"))?; - let participant = room - .participants - .iter_mut() - .find(|participant| participant.peer_id == connection_id.0) - .ok_or_else(|| anyhow!("no such room"))?; - participant - .projects - .retain(|project| project.id != project_id.to_proto()); - - Ok((room, project)) - } else { - Err(anyhow!("no such project"))? 
- } - } - btree_map::Entry::Vacant(_) => Err(anyhow!("no such project"))?, - } - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { @@ -373,17 +323,3 @@ impl Store { } } } - -impl Project { - pub fn guest_connection_ids(&self) -> Vec { - self.guests.keys().copied().collect() - } - - pub fn connection_ids(&self) -> Vec { - self.guests - .keys() - .copied() - .chain(Some(self.host_connection_id)) - .collect() - } -} From 9eee22ff0ab6856a195568409e53b6d91a48f094 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Wed, 16 Nov 2022 14:24:26 -0700 Subject: [PATCH 041/109] Fix column name in query --- crates/collab/src/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index bc74a8e53046889a7583d5cef0eff593d074a740..6741afab7ebb7403637be748d498fee57a0d7d65 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1760,7 +1760,7 @@ where let query = format!( " DELETE FROM worktree_entries - WHERE project_id = ? AND worktree_id = ? AND entry_id IN ({params}) + WHERE project_id = ? AND worktree_id = ? 
AND id IN ({params}) " ); From 532a5992394d96dfaf9bb8921aab8036368a23b6 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 11:38:00 +0100 Subject: [PATCH 042/109] Use `Db::get_guest_connection_ids` in other db methods --- crates/collab/src/db.rs | 57 +++++------------------------------------ 1 file changed, 6 insertions(+), 51 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 6741afab7ebb7403637be748d498fee57a0d7d65..9485d1aae0d201de5843950cba31ce99da84fd19 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1771,24 +1771,9 @@ where query.execute(&mut tx).await?; } - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 AND connection_id != $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - .await?; - + let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; tx.commit().await?; - - Ok(connection_ids - .into_iter() - .map(|connection_id| ConnectionId(connection_id as u32)) - .collect()) + Ok(connection_ids) }) .await } @@ -1846,24 +1831,9 @@ where .execute(&mut tx) .await?; - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 AND connection_id != $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - .await?; - + let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; tx.commit().await?; - - Ok(connection_ids - .into_iter() - .map(|connection_id| ConnectionId(connection_id as u32)) - .collect()) + Ok(connection_ids) }) .await } @@ -1908,24 +1878,9 @@ where .execute(&mut tx) .await?; - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 AND connection_id != $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - 
.await?; - + let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; tx.commit().await?; - - Ok(connection_ids - .into_iter() - .map(|connection_id| ConnectionId(connection_id as u32)) - .collect()) + Ok(connection_ids) }) .await } From 71eeeedc05f7ed6978f2ebfc6f169a7bc9cc8907 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 12:21:51 +0100 Subject: [PATCH 043/109] Don't replace newer diagnostics with older ones --- .../20221109000000_test_schema.sql | 1 + .../20221111092550_reconnection_support.sql | 1 + crates/collab/src/db.rs | 11 ++++++++--- crates/collab/src/integration_tests.rs | 12 ++++++++---- crates/project/src/project.rs | 4 ++++ crates/project/src/worktree.rs | 16 ++++++++++------ crates/rpc/proto/zed.proto | 1 + 7 files changed, 33 insertions(+), 13 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 66925fddd55fba36464eef2fab7b4f30af75362f..bb216eb32d2e8beef1fda0ff55a4ac94a7cc7f4b 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -82,6 +82,7 @@ CREATE TABLE "worktree_diagnostic_summaries" ( "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, + "version" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 4f4ad6aede8b2160c19d0899b505a9ca1c48b3aa..5696dc4a4427bcb312717eafa1715476ea5116a7 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -44,6 +44,7 @@ CREATE TABLE "worktree_diagnostic_summaries" ( 
"language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, + "version" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 9485d1aae0d201de5843950cba31ce99da84fd19..2823b49255a8d0ac77d63aa04183d4b67aa6e83a 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1813,13 +1813,15 @@ where path, language_server_id, error_count, - warning_count + warning_count, + version ) - VALUES ($1, $2, $3, $4, $5, $6) + VALUES ($1, $2, $3, $4, $5, $6, $7) ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET language_server_id = excluded.language_server_id, error_count = excluded.error_count, - warning_count = excluded.warning_count + warning_count = excluded.warning_count, + version = excluded.version ", ) .bind(project_id) @@ -1828,6 +1830,7 @@ where .bind(summary.language_server_id as i64) .bind(summary.error_count as i32) .bind(summary.warning_count as i32) + .bind(summary.version as i32) .execute(&mut tx) .await?; @@ -2042,6 +2045,7 @@ where language_server_id: summary.language_server_id as u64, error_count: summary.error_count as u32, warning_count: summary.warning_count as u32, + version: summary.version as u32, }); } } @@ -2666,6 +2670,7 @@ struct WorktreeDiagnosticSummary { language_server_id: i64, error_count: i32, warning_count: i32, + version: i32, } id_type!(LanguageServerId); diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 1236af42cb05af4b544f74166284d34aa3e44739..d730b5d4e777640b3d3b643a31cb3b1225b195b6 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -2412,9 +2412,10 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { + language_server_id: 0, error_count: 1, warning_count: 
0, - ..Default::default() + version: 2, }, )] ) @@ -2444,9 +2445,10 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { + language_server_id: 0, error_count: 1, warning_count: 0, - ..Default::default() + version: 2, }, )] ); @@ -2484,9 +2486,10 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { + language_server_id: 0, error_count: 1, warning_count: 1, - ..Default::default() + version: 3, }, )] ); @@ -2500,9 +2503,10 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { + language_server_id: 0, error_count: 1, warning_count: 1, - ..Default::default() + version: 3, }, )] ); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 503ae8d4b24cc290e539121e50e2803939a9ecc7..9d7323f989254cf9ef3728f07338e953b8b7397b 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -223,6 +223,7 @@ pub struct DiagnosticSummary { pub language_server_id: usize, pub error_count: usize, pub warning_count: usize, + pub version: usize, } #[derive(Debug, Clone)] @@ -293,12 +294,14 @@ pub struct ProjectTransaction(pub HashMap, language::Transac impl DiagnosticSummary { fn new<'a, T: 'a>( language_server_id: usize, + version: usize, diagnostics: impl IntoIterator>, ) -> Self { let mut this = Self { language_server_id, error_count: 0, warning_count: 0, + version, }; for entry in diagnostics { @@ -324,6 +327,7 @@ impl DiagnosticSummary { language_server_id: self.language_server_id as u64, error_count: self.error_count as u32, warning_count: self.warning_count as u32, + version: self.version as u32, } } } diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 836ac55b661157f8c2f0297567b55143b8b26d2a..04e77cf09af3a395896d5bba9d7ef2fb54ba1ccf 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -366,6 +366,7 @@ impl Worktree 
{ Worktree::Remote(worktree) => &worktree.diagnostic_summaries, } .iter() + .filter(|(_, summary)| !summary.is_empty()) .map(|(path, summary)| (path.0.clone(), *summary)) } @@ -516,7 +517,8 @@ impl LocalWorktree { .diagnostic_summaries .remove(&PathKey(worktree_path.clone())) .unwrap_or_default(); - let new_summary = DiagnosticSummary::new(language_server_id, &diagnostics); + let new_summary = + DiagnosticSummary::new(language_server_id, old_summary.version + 1, &diagnostics); if !new_summary.is_empty() { self.diagnostic_summaries .insert(PathKey(worktree_path.clone()), new_summary); @@ -1106,15 +1108,17 @@ impl RemoteWorktree { path: Arc, summary: &proto::DiagnosticSummary, ) { - let summary = DiagnosticSummary { + let old_summary = self.diagnostic_summaries.get(&PathKey(path.clone())); + let new_summary = DiagnosticSummary { language_server_id: summary.language_server_id as usize, error_count: summary.error_count as usize, warning_count: summary.warning_count as usize, + version: summary.version as usize, }; - if summary.is_empty() { - self.diagnostic_summaries.remove(&PathKey(path)); - } else { - self.diagnostic_summaries.insert(PathKey(path), summary); + if old_summary.map_or(true, |old_summary| { + new_summary.version >= old_summary.version + }) { + self.diagnostic_summaries.insert(PathKey(path), new_summary); } } diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 30c1c89e8f8b393f96e13c96ad9ea42e14ff7a7e..b6d4b83b3b8e65c1c3c1a20ce7dc40c4452d31cb 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -652,6 +652,7 @@ message DiagnosticSummary { uint64 language_server_id = 2; uint32 error_count = 3; uint32 warning_count = 4; + uint32 version = 5; } message UpdateLanguageServer { From 3b34d858b5b5143a0549179a502f6a25e8e905ce Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 13:33:26 +0100 Subject: [PATCH 044/109] Remove unwrap from `Server::share_project` --- crates/collab/src/rpc.rs | 3 +-- 1 
file changed, 1 insertion(+), 2 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 45330ca8583eb18468abcd791ca0a2d6804cf60e..70419623ef1ee18a49401953ec9cf4d2b47e2bb2 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -862,8 +862,7 @@ impl Server { request.sender_connection_id, &request.payload.worktrees, ) - .await - .unwrap(); + .await?; response.send(proto::ShareProjectResponse { project_id: project_id.to_proto(), })?; From fe93263ad450a1460ccb5edfde1ca868d132e8c6 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 14:12:00 +0100 Subject: [PATCH 045/109] Wait for previous `UpdateFollowers` message ack before sending new ones --- crates/collab/src/integration_tests.rs | 82 +++++++++++++++++--------- crates/collab/src/rpc.rs | 4 +- crates/rpc/src/proto.rs | 1 + crates/workspace/src/workspace.rs | 76 +++++++++++++++--------- 4 files changed, 106 insertions(+), 57 deletions(-) diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index d730b5d4e777640b3d3b643a31cb3b1225b195b6..511851002443aada5f27a0c7f7508c5bd560e0b5 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -4672,7 +4672,7 @@ async fn test_following( cx_a: &mut TestAppContext, cx_b: &mut TestAppContext, ) { - cx_a.foreground().forbid_parking(); + deterministic.forbid_parking(); cx_a.update(editor::init); cx_b.update(editor::init); @@ -4791,11 +4791,14 @@ async fn test_following( workspace_a.update(cx_a, |workspace, cx| { workspace.activate_item(&editor_a1, cx) }); - workspace_b - .condition(cx_b, |workspace, cx| { - workspace.active_item(cx).unwrap().id() == editor_b1.id() - }) - .await; + deterministic.run_until_parked(); + assert_eq!( + workspace_b.read_with(cx_b, |workspace, cx| workspace + .active_item(cx) + .unwrap() + .id()), + editor_b1.id() + ); // When client A navigates back and forth, client B does so as well. 
workspace_a @@ -4803,49 +4806,74 @@ async fn test_following( workspace::Pane::go_back(workspace, None, cx) }) .await; - workspace_b - .condition(cx_b, |workspace, cx| { - workspace.active_item(cx).unwrap().id() == editor_b2.id() - }) - .await; + deterministic.run_until_parked(); + assert_eq!( + workspace_b.read_with(cx_b, |workspace, cx| workspace + .active_item(cx) + .unwrap() + .id()), + editor_b2.id() + ); workspace_a .update(cx_a, |workspace, cx| { workspace::Pane::go_forward(workspace, None, cx) }) .await; - workspace_b - .condition(cx_b, |workspace, cx| { - workspace.active_item(cx).unwrap().id() == editor_b1.id() + workspace_a + .update(cx_a, |workspace, cx| { + workspace::Pane::go_back(workspace, None, cx) + }) + .await; + workspace_a + .update(cx_a, |workspace, cx| { + workspace::Pane::go_forward(workspace, None, cx) }) .await; + deterministic.run_until_parked(); + assert_eq!( + workspace_b.read_with(cx_b, |workspace, cx| workspace + .active_item(cx) + .unwrap() + .id()), + editor_b1.id() + ); // Changes to client A's editor are reflected on client B. 
editor_a1.update(cx_a, |editor, cx| { editor.change_selections(None, cx, |s| s.select_ranges([1..1, 2..2])); }); - editor_b1 - .condition(cx_b, |editor, cx| { - editor.selections.ranges(cx) == vec![1..1, 2..2] - }) - .await; + deterministic.run_until_parked(); + assert_eq!( + editor_b1.read_with(cx_b, |editor, cx| editor.selections.ranges(cx)), + vec![1..1, 2..2] + ); editor_a1.update(cx_a, |editor, cx| editor.set_text("TWO", cx)); - editor_b1 - .condition(cx_b, |editor, cx| editor.text(cx) == "TWO") - .await; + deterministic.run_until_parked(); + assert_eq!( + editor_b1.read_with(cx_b, |editor, cx| editor.text(cx)), + "TWO" + ); editor_a1.update(cx_a, |editor, cx| { editor.change_selections(None, cx, |s| s.select_ranges([3..3])); editor.set_scroll_position(vec2f(0., 100.), cx); }); - editor_b1 - .condition(cx_b, |editor, cx| { - editor.selections.ranges(cx) == vec![3..3] - }) - .await; + deterministic.run_until_parked(); + assert_eq!( + editor_b1.read_with(cx_b, |editor, cx| editor.selections.ranges(cx)), + vec![3..3] + ); // After unfollowing, client B stops receiving updates from client A. 
+ assert_eq!( + workspace_b.read_with(cx_b, |workspace, cx| workspace + .active_item(cx) + .unwrap() + .id()), + editor_b1.id() + ); workspace_b.update(cx_b, |workspace, cx| { workspace.unfollow(&workspace.active_pane().clone(), cx) }); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 70419623ef1ee18a49401953ec9cf4d2b47e2bb2..a07a8b37c870a3d070b840dc59c96a23a27c2087 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -192,7 +192,7 @@ impl Server { .add_request_handler(Server::respond_to_contact_request) .add_request_handler(Server::follow) .add_message_handler(Server::unfollow) - .add_message_handler(Server::update_followers) + .add_request_handler(Server::update_followers) .add_message_handler(Server::update_diff_base) .add_request_handler(Server::get_private_user_info); @@ -1437,6 +1437,7 @@ impl Server { async fn update_followers( self: Arc, request: Message, + response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let project_connection_ids = self @@ -1464,6 +1465,7 @@ impl Server { )?; } } + response.send(proto::Ack {})?; Ok(()) } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 50f3c57f2a6b3c5bd9bc6798e468df7a541a2f07..8a59818fa3d2bb95423465014456901daa945897 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -229,6 +229,7 @@ request_messages!( (Test, Test), (UpdateBuffer, Ack), (UpdateDiagnosticSummary, Ack), + (UpdateFollowers, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), (UpdateWorktree, Ack), diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 2296741ed3c7f31768c2bd5857a463e18179c4fe..5f14427feea53cc1b19e2674eabf374b9d4254be 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -18,7 +18,10 @@ use collections::{hash_map, HashMap, HashSet}; use dock::{DefaultItemFactory, Dock, ToggleDockButton}; use drag_and_drop::DragAndDrop; use 
fs::{self, Fs}; -use futures::{channel::oneshot, FutureExt, StreamExt}; +use futures::{ + channel::{mpsc, oneshot}, + FutureExt, StreamExt, +}; use gpui::{ actions, elements::*, @@ -711,14 +714,13 @@ impl ItemHandle for ViewHandle { if let Some(followed_item) = self.to_followable_item_handle(cx) { if let Some(message) = followed_item.to_state_proto(cx) { - workspace.update_followers( - proto::update_followers::Variant::CreateView(proto::View { + workspace.update_followers(proto::update_followers::Variant::CreateView( + proto::View { id: followed_item.id() as u64, variant: Some(message), leader_id: workspace.leader_for_pane(&pane).map(|id| id.0), - }), - cx, - ); + }, + )); } } @@ -762,7 +764,7 @@ impl ItemHandle for ViewHandle { cx.after_window_update({ let pending_update = pending_update.clone(); let pending_update_scheduled = pending_update_scheduled.clone(); - move |this, cx| { + move |this, _| { pending_update_scheduled.store(false, SeqCst); this.update_followers( proto::update_followers::Variant::UpdateView( @@ -772,7 +774,6 @@ impl ItemHandle for ViewHandle { leader_id: leader_id.map(|id| id.0), }, ), - cx, ); } }); @@ -1081,9 +1082,11 @@ pub struct Workspace { leader_state: LeaderState, follower_states_by_leader: FollowerStatesByLeader, last_leaders_by_pane: HashMap, PeerId>, + follower_updates: mpsc::UnboundedSender, window_edited: bool, active_call: Option<(ModelHandle, Vec)>, _observe_current_user: Task<()>, + _update_followers: Task>, } #[derive(Default)] @@ -1166,6 +1169,34 @@ impl Workspace { } }); + let (follower_updates_tx, mut follower_updates_rx) = mpsc::unbounded(); + let _update_followers = cx.spawn_weak(|this, cx| async move { + while let Some(update) = follower_updates_rx.next().await { + let this = this.upgrade(&cx)?; + let update_followers = this.read_with(&cx, |this, cx| { + if let Some(project_id) = this.project.read(cx).remote_id() { + if this.leader_state.followers.is_empty() { + None + } else { + 
Some(this.client.request(proto::UpdateFollowers { + project_id, + follower_ids: + this.leader_state.followers.iter().map(|f| f.0).collect(), + variant: Some(update), + })) + } + } else { + None + } + }); + + if let Some(update_followers) = update_followers { + update_followers.await.log_err(); + } + } + None + }); + let handle = cx.handle(); let weak_handle = cx.weak_handle(); @@ -1224,10 +1255,12 @@ impl Workspace { project, leader_state: Default::default(), follower_states_by_leader: Default::default(), + follower_updates: follower_updates_tx, last_leaders_by_pane: Default::default(), window_edited: false, active_call, _observe_current_user, + _update_followers, }; this.project_remote_id_changed(this.project.read(cx).remote_id(), cx); cx.defer(|this, cx| this.update_window_title(cx)); @@ -1967,13 +2000,12 @@ impl Workspace { cx.notify(); } - self.update_followers( - proto::update_followers::Variant::UpdateActiveView(proto::UpdateActiveView { + self.update_followers(proto::update_followers::Variant::UpdateActiveView( + proto::UpdateActiveView { id: self.active_item(cx).map(|item| item.id() as u64), leader_id: self.leader_for_pane(&pane).map(|id| id.0), - }), - cx, - ); + }, + )); } fn handle_pane_event( @@ -2594,22 +2626,8 @@ impl Workspace { Ok(()) } - fn update_followers( - &self, - update: proto::update_followers::Variant, - cx: &AppContext, - ) -> Option<()> { - let project_id = self.project.read(cx).remote_id()?; - if !self.leader_state.followers.is_empty() { - self.client - .send(proto::UpdateFollowers { - project_id, - follower_ids: self.leader_state.followers.iter().map(|f| f.0).collect(), - variant: Some(update), - }) - .log_err(); - } - None + fn update_followers(&self, update: proto::update_followers::Variant) { + let _ = self.follower_updates.unbounded_send(update); } pub fn leader_for_pane(&self, pane: &ViewHandle) -> Option { From 6415809b610e4bfb158ab6ea257929fb410bbb16 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 15:34:12 
+0100 Subject: [PATCH 046/109] Fix errors in Postgres schema --- .../collab/migrations/20221111092550_reconnection_support.sql | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 5696dc4a4427bcb312717eafa1715476ea5116a7..50a4a7154b5c433ea865d1b49bae423b64f044bf 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -59,7 +59,7 @@ CREATE TABLE "language_servers" ( CREATE INDEX "index_language_servers_on_project_id" ON "language_servers" ("project_id"); CREATE TABLE "project_collaborators" ( - "id" INTEGER PRIMARY KEY, + "id" SERIAL PRIMARY KEY, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, "user_id" INTEGER NOT NULL, @@ -70,7 +70,7 @@ CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborato CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); CREATE TABLE "room_participants" ( - "id" INTEGER PRIMARY KEY, + "id" SERIAL PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), "user_id" INTEGER NOT NULL REFERENCES users (id), "answering_connection_id" INTEGER, From 0f4598a2435f34f15ed739a7dd75419eff05d4c5 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 15:34:35 +0100 Subject: [PATCH 047/109] Fix seed script --- crates/collab/src/bin/seed.rs | 58 ++--------------------------------- 1 file changed, 2 insertions(+), 56 deletions(-) diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index cabea7d013776d4f3cb248d1b0c8985a0f3090a2..3b635540b315bfbebe6058f9457e65237a0f1e3b 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -1,9 +1,7 @@ use collab::{Error, Result}; -use db::{Db, 
PostgresDb, UserId}; -use rand::prelude::*; +use db::{DefaultDb, UserId}; use serde::{de::DeserializeOwned, Deserialize}; use std::fmt::Write; -use time::{Duration, OffsetDateTime}; #[allow(unused)] #[path = "../db.rs"] @@ -18,9 +16,8 @@ struct GitHubUser { #[tokio::main] async fn main() { - let mut rng = StdRng::from_entropy(); let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var"); - let db = PostgresDb::new(&database_url, 5) + let db = DefaultDb::new(&database_url, 5) .await .expect("failed to connect to postgres database"); let github_token = std::env::var("GITHUB_TOKEN").expect("missing GITHUB_TOKEN env var"); @@ -104,57 +101,6 @@ async fn main() { ); } } - - let zed_org_id = if let Some(org) = db - .find_org_by_slug("zed") - .await - .expect("failed to fetch org") - { - org.id - } else { - db.create_org("Zed", "zed") - .await - .expect("failed to insert org") - }; - - let general_channel_id = if let Some(channel) = db - .get_org_channels(zed_org_id) - .await - .expect("failed to fetch channels") - .iter() - .find(|c| c.name == "General") - { - channel.id - } else { - let channel_id = db - .create_org_channel(zed_org_id, "General") - .await - .expect("failed to insert channel"); - - let now = OffsetDateTime::now_utc(); - let max_seconds = Duration::days(100).as_seconds_f64(); - let mut timestamps = (0..1000) - .map(|_| now - Duration::seconds_f64(rng.gen_range(0_f64..=max_seconds))) - .collect::>(); - timestamps.sort(); - for timestamp in timestamps { - let sender_id = *zed_user_ids.choose(&mut rng).unwrap(); - let body = lipsum::lipsum_words(rng.gen_range(1..=50)); - db.create_channel_message(channel_id, sender_id, &body, timestamp, rng.gen()) - .await - .expect("failed to insert message"); - } - channel_id - }; - - for user_id in zed_user_ids { - db.add_org_member(zed_org_id, user_id, true) - .await - .expect("failed to insert org membership"); - db.add_channel_member(general_channel_id, user_id, true) - .await - 
.expect("failed to insert channel membership"); - } } async fn fetch_github( From 7dae21cb36f3dbf6182b0db0f9752567438c95d5 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 15:35:03 +0100 Subject: [PATCH 048/109] :art: --- crates/collab/src/db.rs | 14 +++++--------- crates/collab/src/rpc.rs | 2 +- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 2823b49255a8d0ac77d63aa04183d4b67aa6e83a..55c71ea92e0c53290bb1bb1de9f7746864089e1d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -49,6 +49,7 @@ impl BeginTransaction for Db { } // In Sqlite, transactions are inherently serializable. +#[cfg(test)] impl BeginTransaction for Db { type Database = sqlx::Sqlite; @@ -1141,10 +1142,7 @@ where .await } - pub async fn leave_room_for_connection( - &self, - connection_id: ConnectionId, - ) -> Result> { + pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { self.transact(|mut tx| async move { // Leave room. 
let room_id = sqlx::query_scalar::<_, RoomId>( @@ -1498,8 +1496,7 @@ where .bind(user_id) .bind(connection_id.0 as i32) .fetch_one(&mut tx) - .await - .unwrap(); + .await?; if !worktrees.is_empty() { let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); @@ -1530,7 +1527,7 @@ where .bind(0) .bind(false); } - query.execute(&mut tx).await.unwrap(); + query.execute(&mut tx).await?; } sqlx::query( @@ -1551,8 +1548,7 @@ where .bind(0) .bind(true) .execute(&mut tx) - .await - .unwrap(); + .await?; let room = self.commit_room_transaction(room_id, tx).await?; Ok((project_id, room)) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index a07a8b37c870a3d070b840dc59c96a23a27c2087..9e0335ef1b16b432b8a4fdefcfe4e909bbd954ac 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -629,7 +629,7 @@ impl Server { ) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else { + let Some(left_room) = self.app_state.db.leave_room(leaving_connection_id).await? 
else { return Err(anyhow!("no room to leave"))?; }; contacts_to_update.insert(leaving_user_id); From 8621c88a3ce088808b64fe03a4771dac7c62de7a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 16:56:43 +0100 Subject: [PATCH 049/109] Use int8 for `scan_id` and `inode` in Postgres --- .../collab/migrations/20221111092550_reconnection_support.sql | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 50a4a7154b5c433ea865d1b49bae423b64f044bf..de29f0c878ca0b710f5796a70b49ce0720080418 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -15,7 +15,7 @@ CREATE TABLE "worktrees" ( "root_name" VARCHAR NOT NULL, "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, - "scan_id" INTEGER NOT NULL, + "scan_id" INT8 NOT NULL, "is_complete" BOOL NOT NULL, PRIMARY KEY(project_id, id) ); @@ -27,7 +27,7 @@ CREATE TABLE "worktree_entries" ( "id" INTEGER NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, - "inode" INTEGER NOT NULL, + "inode" INT8 NOT NULL, "mtime_seconds" INTEGER NOT NULL, "mtime_nanos" INTEGER NOT NULL, "is_symlink" BOOL NOT NULL, From e7e45be6e141ac50db80cf66d1445afb8163d681 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 16:57:32 +0100 Subject: [PATCH 050/109] Revert "Wait for previous `UpdateFollowers` message ack before sending new ones" This reverts commit fe93263ad450a1460ccb5edfde1ca868d132e8c6. 
--- crates/collab/src/integration_tests.rs | 82 +++++++++----------------- crates/collab/src/rpc.rs | 4 +- crates/rpc/src/proto.rs | 1 - crates/workspace/src/workspace.rs | 76 +++++++++--------------- 4 files changed, 57 insertions(+), 106 deletions(-) diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 511851002443aada5f27a0c7f7508c5bd560e0b5..d730b5d4e777640b3d3b643a31cb3b1225b195b6 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -4672,7 +4672,7 @@ async fn test_following( cx_a: &mut TestAppContext, cx_b: &mut TestAppContext, ) { - deterministic.forbid_parking(); + cx_a.foreground().forbid_parking(); cx_a.update(editor::init); cx_b.update(editor::init); @@ -4791,14 +4791,11 @@ async fn test_following( workspace_a.update(cx_a, |workspace, cx| { workspace.activate_item(&editor_a1, cx) }); - deterministic.run_until_parked(); - assert_eq!( - workspace_b.read_with(cx_b, |workspace, cx| workspace - .active_item(cx) - .unwrap() - .id()), - editor_b1.id() - ); + workspace_b + .condition(cx_b, |workspace, cx| { + workspace.active_item(cx).unwrap().id() == editor_b1.id() + }) + .await; // When client A navigates back and forth, client B does so as well. 
workspace_a @@ -4806,74 +4803,49 @@ async fn test_following( workspace::Pane::go_back(workspace, None, cx) }) .await; - deterministic.run_until_parked(); - assert_eq!( - workspace_b.read_with(cx_b, |workspace, cx| workspace - .active_item(cx) - .unwrap() - .id()), - editor_b2.id() - ); - - workspace_a - .update(cx_a, |workspace, cx| { - workspace::Pane::go_forward(workspace, None, cx) + workspace_b + .condition(cx_b, |workspace, cx| { + workspace.active_item(cx).unwrap().id() == editor_b2.id() }) .await; + workspace_a .update(cx_a, |workspace, cx| { - workspace::Pane::go_back(workspace, None, cx) + workspace::Pane::go_forward(workspace, None, cx) }) .await; - workspace_a - .update(cx_a, |workspace, cx| { - workspace::Pane::go_forward(workspace, None, cx) + workspace_b + .condition(cx_b, |workspace, cx| { + workspace.active_item(cx).unwrap().id() == editor_b1.id() }) .await; - deterministic.run_until_parked(); - assert_eq!( - workspace_b.read_with(cx_b, |workspace, cx| workspace - .active_item(cx) - .unwrap() - .id()), - editor_b1.id() - ); // Changes to client A's editor are reflected on client B. 
editor_a1.update(cx_a, |editor, cx| { editor.change_selections(None, cx, |s| s.select_ranges([1..1, 2..2])); }); - deterministic.run_until_parked(); - assert_eq!( - editor_b1.read_with(cx_b, |editor, cx| editor.selections.ranges(cx)), - vec![1..1, 2..2] - ); + editor_b1 + .condition(cx_b, |editor, cx| { + editor.selections.ranges(cx) == vec![1..1, 2..2] + }) + .await; editor_a1.update(cx_a, |editor, cx| editor.set_text("TWO", cx)); - deterministic.run_until_parked(); - assert_eq!( - editor_b1.read_with(cx_b, |editor, cx| editor.text(cx)), - "TWO" - ); + editor_b1 + .condition(cx_b, |editor, cx| editor.text(cx) == "TWO") + .await; editor_a1.update(cx_a, |editor, cx| { editor.change_selections(None, cx, |s| s.select_ranges([3..3])); editor.set_scroll_position(vec2f(0., 100.), cx); }); - deterministic.run_until_parked(); - assert_eq!( - editor_b1.read_with(cx_b, |editor, cx| editor.selections.ranges(cx)), - vec![3..3] - ); + editor_b1 + .condition(cx_b, |editor, cx| { + editor.selections.ranges(cx) == vec![3..3] + }) + .await; // After unfollowing, client B stops receiving updates from client A. 
- assert_eq!( - workspace_b.read_with(cx_b, |workspace, cx| workspace - .active_item(cx) - .unwrap() - .id()), - editor_b1.id() - ); workspace_b.update(cx_b, |workspace, cx| { workspace.unfollow(&workspace.active_pane().clone(), cx) }); diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9e0335ef1b16b432b8a4fdefcfe4e909bbd954ac..4375056c9aa865905d23d98242aa02793bf8f97a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -192,7 +192,7 @@ impl Server { .add_request_handler(Server::respond_to_contact_request) .add_request_handler(Server::follow) .add_message_handler(Server::unfollow) - .add_request_handler(Server::update_followers) + .add_message_handler(Server::update_followers) .add_message_handler(Server::update_diff_base) .add_request_handler(Server::get_private_user_info); @@ -1437,7 +1437,6 @@ impl Server { async fn update_followers( self: Arc, request: Message, - response: Response, ) -> Result<()> { let project_id = ProjectId::from_proto(request.payload.project_id); let project_connection_ids = self @@ -1465,7 +1464,6 @@ impl Server { )?; } } - response.send(proto::Ack {})?; Ok(()) } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 8a59818fa3d2bb95423465014456901daa945897..50f3c57f2a6b3c5bd9bc6798e468df7a541a2f07 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -229,7 +229,6 @@ request_messages!( (Test, Test), (UpdateBuffer, Ack), (UpdateDiagnosticSummary, Ack), - (UpdateFollowers, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), (UpdateWorktree, Ack), diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 5f14427feea53cc1b19e2674eabf374b9d4254be..2296741ed3c7f31768c2bd5857a463e18179c4fe 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -18,10 +18,7 @@ use collections::{hash_map, HashMap, HashSet}; use dock::{DefaultItemFactory, Dock, ToggleDockButton}; use drag_and_drop::DragAndDrop; use 
fs::{self, Fs}; -use futures::{ - channel::{mpsc, oneshot}, - FutureExt, StreamExt, -}; +use futures::{channel::oneshot, FutureExt, StreamExt}; use gpui::{ actions, elements::*, @@ -714,13 +711,14 @@ impl ItemHandle for ViewHandle { if let Some(followed_item) = self.to_followable_item_handle(cx) { if let Some(message) = followed_item.to_state_proto(cx) { - workspace.update_followers(proto::update_followers::Variant::CreateView( - proto::View { + workspace.update_followers( + proto::update_followers::Variant::CreateView(proto::View { id: followed_item.id() as u64, variant: Some(message), leader_id: workspace.leader_for_pane(&pane).map(|id| id.0), - }, - )); + }), + cx, + ); } } @@ -764,7 +762,7 @@ impl ItemHandle for ViewHandle { cx.after_window_update({ let pending_update = pending_update.clone(); let pending_update_scheduled = pending_update_scheduled.clone(); - move |this, _| { + move |this, cx| { pending_update_scheduled.store(false, SeqCst); this.update_followers( proto::update_followers::Variant::UpdateView( @@ -774,6 +772,7 @@ impl ItemHandle for ViewHandle { leader_id: leader_id.map(|id| id.0), }, ), + cx, ); } }); @@ -1082,11 +1081,9 @@ pub struct Workspace { leader_state: LeaderState, follower_states_by_leader: FollowerStatesByLeader, last_leaders_by_pane: HashMap, PeerId>, - follower_updates: mpsc::UnboundedSender, window_edited: bool, active_call: Option<(ModelHandle, Vec)>, _observe_current_user: Task<()>, - _update_followers: Task>, } #[derive(Default)] @@ -1169,34 +1166,6 @@ impl Workspace { } }); - let (follower_updates_tx, mut follower_updates_rx) = mpsc::unbounded(); - let _update_followers = cx.spawn_weak(|this, cx| async move { - while let Some(update) = follower_updates_rx.next().await { - let this = this.upgrade(&cx)?; - let update_followers = this.read_with(&cx, |this, cx| { - if let Some(project_id) = this.project.read(cx).remote_id() { - if this.leader_state.followers.is_empty() { - None - } else { - 
Some(this.client.request(proto::UpdateFollowers { - project_id, - follower_ids: - this.leader_state.followers.iter().map(|f| f.0).collect(), - variant: Some(update), - })) - } - } else { - None - } - }); - - if let Some(update_followers) = update_followers { - update_followers.await.log_err(); - } - } - None - }); - let handle = cx.handle(); let weak_handle = cx.weak_handle(); @@ -1255,12 +1224,10 @@ impl Workspace { project, leader_state: Default::default(), follower_states_by_leader: Default::default(), - follower_updates: follower_updates_tx, last_leaders_by_pane: Default::default(), window_edited: false, active_call, _observe_current_user, - _update_followers, }; this.project_remote_id_changed(this.project.read(cx).remote_id(), cx); cx.defer(|this, cx| this.update_window_title(cx)); @@ -2000,12 +1967,13 @@ impl Workspace { cx.notify(); } - self.update_followers(proto::update_followers::Variant::UpdateActiveView( - proto::UpdateActiveView { + self.update_followers( + proto::update_followers::Variant::UpdateActiveView(proto::UpdateActiveView { id: self.active_item(cx).map(|item| item.id() as u64), leader_id: self.leader_for_pane(&pane).map(|id| id.0), - }, - )); + }), + cx, + ); } fn handle_pane_event( @@ -2626,8 +2594,22 @@ impl Workspace { Ok(()) } - fn update_followers(&self, update: proto::update_followers::Variant) { - let _ = self.follower_updates.unbounded_send(update); + fn update_followers( + &self, + update: proto::update_followers::Variant, + cx: &AppContext, + ) -> Option<()> { + let project_id = self.project.read(cx).remote_id()?; + if !self.leader_state.followers.is_empty() { + self.client + .send(proto::UpdateFollowers { + project_id, + follower_ids: self.leader_state.followers.iter().map(|f| f.0).collect(), + variant: Some(update), + }) + .log_err(); + } + None } pub fn leader_for_pane(&self, pane: &ViewHandle) -> Option { From 4f39181c4cbd7b1845aa9ec3ff0fea59c80d4c86 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 16:57:40 
+0100 Subject: [PATCH 051/109] Revert "Don't replace newer diagnostics with older ones" This reverts commit 71eeeedc05f7ed6978f2ebfc6f169a7bc9cc8907. --- .../20221109000000_test_schema.sql | 1 - .../20221111092550_reconnection_support.sql | 1 - crates/collab/src/db.rs | 11 +++-------- crates/collab/src/integration_tests.rs | 12 ++++-------- crates/project/src/project.rs | 4 ---- crates/project/src/worktree.rs | 16 ++++++---------- crates/rpc/proto/zed.proto | 1 - 7 files changed, 13 insertions(+), 33 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index bb216eb32d2e8beef1fda0ff55a4ac94a7cc7f4b..66925fddd55fba36464eef2fab7b4f30af75362f 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -82,7 +82,6 @@ CREATE TABLE "worktree_diagnostic_summaries" ( "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, - "version" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index de29f0c878ca0b710f5796a70b49ce0720080418..2b8f7824cb4bea6a138fc983ee206d69464aedf0 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -44,7 +44,6 @@ CREATE TABLE "worktree_diagnostic_summaries" ( "language_server_id" INTEGER NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, - "version" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); diff --git a/crates/collab/src/db.rs 
b/crates/collab/src/db.rs index 55c71ea92e0c53290bb1bb1de9f7746864089e1d..c97c82c656e022596d6a9bbaf7f51f63137d5df4 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1809,15 +1809,13 @@ where path, language_server_id, error_count, - warning_count, - version + warning_count ) - VALUES ($1, $2, $3, $4, $5, $6, $7) + VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET language_server_id = excluded.language_server_id, error_count = excluded.error_count, - warning_count = excluded.warning_count, - version = excluded.version + warning_count = excluded.warning_count ", ) .bind(project_id) @@ -1826,7 +1824,6 @@ where .bind(summary.language_server_id as i64) .bind(summary.error_count as i32) .bind(summary.warning_count as i32) - .bind(summary.version as i32) .execute(&mut tx) .await?; @@ -2041,7 +2038,6 @@ where language_server_id: summary.language_server_id as u64, error_count: summary.error_count as u32, warning_count: summary.warning_count as u32, - version: summary.version as u32, }); } } @@ -2666,7 +2662,6 @@ struct WorktreeDiagnosticSummary { language_server_id: i64, error_count: i32, warning_count: i32, - version: i32, } id_type!(LanguageServerId); diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index d730b5d4e777640b3d3b643a31cb3b1225b195b6..1236af42cb05af4b544f74166284d34aa3e44739 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -2412,10 +2412,9 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { - language_server_id: 0, error_count: 1, warning_count: 0, - version: 2, + ..Default::default() }, )] ) @@ -2445,10 +2444,9 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { - language_server_id: 0, error_count: 1, warning_count: 0, - version: 2, + ..Default::default() }, )] ); @@ -2486,10 +2484,9 @@ async fn 
test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { - language_server_id: 0, error_count: 1, warning_count: 1, - version: 3, + ..Default::default() }, )] ); @@ -2503,10 +2500,9 @@ async fn test_collaborating_with_diagnostics( path: Arc::from(Path::new("a.rs")), }, DiagnosticSummary { - language_server_id: 0, error_count: 1, warning_count: 1, - version: 3, + ..Default::default() }, )] ); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 9d7323f989254cf9ef3728f07338e953b8b7397b..503ae8d4b24cc290e539121e50e2803939a9ecc7 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -223,7 +223,6 @@ pub struct DiagnosticSummary { pub language_server_id: usize, pub error_count: usize, pub warning_count: usize, - pub version: usize, } #[derive(Debug, Clone)] @@ -294,14 +293,12 @@ pub struct ProjectTransaction(pub HashMap, language::Transac impl DiagnosticSummary { fn new<'a, T: 'a>( language_server_id: usize, - version: usize, diagnostics: impl IntoIterator>, ) -> Self { let mut this = Self { language_server_id, error_count: 0, warning_count: 0, - version, }; for entry in diagnostics { @@ -327,7 +324,6 @@ impl DiagnosticSummary { language_server_id: self.language_server_id as u64, error_count: self.error_count as u32, warning_count: self.warning_count as u32, - version: self.version as u32, } } } diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 04e77cf09af3a395896d5bba9d7ef2fb54ba1ccf..836ac55b661157f8c2f0297567b55143b8b26d2a 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -366,7 +366,6 @@ impl Worktree { Worktree::Remote(worktree) => &worktree.diagnostic_summaries, } .iter() - .filter(|(_, summary)| !summary.is_empty()) .map(|(path, summary)| (path.0.clone(), *summary)) } @@ -517,8 +516,7 @@ impl LocalWorktree { .diagnostic_summaries .remove(&PathKey(worktree_path.clone())) .unwrap_or_default(); - let 
new_summary = - DiagnosticSummary::new(language_server_id, old_summary.version + 1, &diagnostics); + let new_summary = DiagnosticSummary::new(language_server_id, &diagnostics); if !new_summary.is_empty() { self.diagnostic_summaries .insert(PathKey(worktree_path.clone()), new_summary); @@ -1108,17 +1106,15 @@ impl RemoteWorktree { path: Arc, summary: &proto::DiagnosticSummary, ) { - let old_summary = self.diagnostic_summaries.get(&PathKey(path.clone())); - let new_summary = DiagnosticSummary { + let summary = DiagnosticSummary { language_server_id: summary.language_server_id as usize, error_count: summary.error_count as usize, warning_count: summary.warning_count as usize, - version: summary.version as usize, }; - if old_summary.map_or(true, |old_summary| { - new_summary.version >= old_summary.version - }) { - self.diagnostic_summaries.insert(PathKey(path), new_summary); + if summary.is_empty() { + self.diagnostic_summaries.remove(&PathKey(path)); + } else { + self.diagnostic_summaries.insert(PathKey(path), summary); } } diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index b6d4b83b3b8e65c1c3c1a20ce7dc40c4452d31cb..30c1c89e8f8b393f96e13c96ad9ea42e14ff7a7e 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -652,7 +652,6 @@ message DiagnosticSummary { uint64 language_server_id = 2; uint32 error_count = 3; uint32 warning_count = 4; - uint32 version = 5; } message UpdateLanguageServer { From c34a5f3177ee471f631e5d657c7d62673971ca05 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 17:11:06 +0100 Subject: [PATCH 052/109] Introduce a new `Session` struct to server message handlers Co-Authored-By: Nathan Sobo --- crates/collab/src/rpc.rs | 498 ++++++++++++++++++--------------------- 1 file changed, 232 insertions(+), 266 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 4375056c9aa865905d23d98242aa02793bf8f97a..19d45e221d5b39e0df416ea45364a128b5a4c774 100644 --- 
a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -68,21 +68,20 @@ lazy_static! { } type MessageHandler = Box< - dyn Send + Sync + Fn(Arc, UserId, Box) -> BoxFuture<'static, ()>, + dyn Send + Sync + Fn(Arc, Box, Session) -> BoxFuture<'static, ()>, >; -struct Message { - sender_user_id: UserId, - sender_connection_id: ConnectionId, - payload: T, -} - struct Response { server: Arc, receipt: Receipt, responded: Arc, } +struct Session { + user_id: UserId, + connection_id: ConnectionId, +} + impl Response { fn send(self, payload: R::Response) -> Result<()> { self.responded.store(true, SeqCst); @@ -201,13 +200,13 @@ impl Server { fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, UserId, TypedEnvelope) -> Fut, + F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, sender_user_id, envelope| { + Box::new(move |server, envelope, session| { let envelope = envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -219,7 +218,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, sender_user_id, *envelope); + let future = (handler)(server, *envelope, session); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -237,19 +236,12 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, Message) -> Fut, + F: 'static + Send + Sync + Fn(Arc, M, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { - self.add_handler(move |server, sender_user_id, envelope| { - handler( - server, - Message { - sender_user_id, - sender_connection_id: envelope.sender_id, - payload: envelope.payload, - }, - ) + self.add_handler(move |server, envelope, session| { + handler(server, envelope.payload, session) }); self } @@ 
-258,27 +250,22 @@ impl Server { /// a connection but we want to respond on the connection before anybody else can send on it. fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, Message, Response) -> Fut, + F: 'static + Send + Sync + Fn(Arc, M, Response, Session) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_handler(move |server, sender_user_id, envelope| { + self.add_handler(move |server, envelope, session| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { - let request = Message { - sender_user_id, - sender_connection_id: envelope.sender_id, - payload: envelope.payload, - }; let responded = Arc::new(AtomicBool::default()); let response = Response { server: server.clone(), responded: responded.clone(), receipt, }; - match (handler)(server.clone(), request, response).await { + match (handler)(server.clone(), envelope.payload, response, session).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -392,7 +379,11 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let handle_message = (handler)(this.clone(), user_id, message); + let session = Session { + user_id, + connection_id, + }; + let handle_message = (handler)(this.clone(), message, session); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -509,8 +500,9 @@ impl Server { async fn ping( self: Arc, - _: Message, + _: proto::Ping, response: Response, + _session: Session, ) -> Result<()> { response.send(proto::Ack {})?; Ok(()) @@ -518,13 +510,14 @@ impl Server { async fn create_room( self: Arc, - request: Message, + _request: proto::CreateRoom, response: Response, + session: Session, ) -> Result<()> { let room = self .app_state .db - .create_room(request.sender_user_id, request.sender_connection_id) + 
.create_room(session.user_id, session.connection_id) .await?; let live_kit_connection_info = @@ -535,10 +528,7 @@ impl Server { .trace_err() { if let Some(token) = live_kit - .room_token( - &room.live_kit_room, - &request.sender_connection_id.to_string(), - ) + .room_token(&room.live_kit_room, &session.connection_id.to_string()) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -559,29 +549,26 @@ impl Server { room: Some(room), live_kit_connection_info, })?; - self.update_user_contacts(request.sender_user_id).await?; + self.update_user_contacts(session.user_id).await?; Ok(()) } async fn join_room( self: Arc, - request: Message, + request: proto::JoinRoom, response: Response, + session: Session, ) -> Result<()> { let room = self .app_state .db .join_room( - RoomId::from_proto(request.payload.id), - request.sender_user_id, - request.sender_connection_id, + RoomId::from_proto(request.id), + session.user_id, + session.connection_id, ) .await?; - for connection_id in self - .store() - .await - .connection_ids_for_user(request.sender_user_id) - { + for connection_id in self.store().await.connection_ids_for_user(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -590,10 +577,7 @@ impl Server { let live_kit_connection_info = if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { if let Some(token) = live_kit - .room_token( - &room.live_kit_room, - &request.sender_connection_id.to_string(), - ) + .room_token(&room.live_kit_room, &session.connection_id.to_string()) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -613,12 +597,16 @@ impl Server { live_kit_connection_info, })?; - self.update_user_contacts(request.sender_user_id).await?; + self.update_user_contacts(session.user_id).await?; Ok(()) } - async fn leave_room(self: Arc, message: Message) -> Result<()> { - self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id) + async fn leave_room( + self: Arc, + _message: proto::LeaveRoom, + 
session: Session, + ) -> Result<()> { + self.leave_room_for_connection(session.connection_id, session.user_id) .await } @@ -707,17 +695,15 @@ impl Server { async fn call( self: Arc, - request: Message, + request: proto::Call, response: Response, + session: Session, ) -> Result<()> { - let room_id = RoomId::from_proto(request.payload.room_id); - let calling_user_id = request.sender_user_id; - let calling_connection_id = request.sender_connection_id; - let called_user_id = UserId::from_proto(request.payload.called_user_id); - let initial_project_id = request - .payload - .initial_project_id - .map(ProjectId::from_proto); + let room_id = RoomId::from_proto(request.room_id); + let calling_user_id = session.user_id; + let calling_connection_id = session.connection_id; + let called_user_id = UserId::from_proto(request.called_user_id); + let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); if !self .app_state .db @@ -773,15 +759,16 @@ impl Server { async fn cancel_call( self: Arc, - request: Message, + request: proto::CancelCall, response: Response, + session: Session, ) -> Result<()> { - let called_user_id = UserId::from_proto(request.payload.called_user_id); - let room_id = RoomId::from_proto(request.payload.room_id); + let called_user_id = UserId::from_proto(request.called_user_id); + let room_id = RoomId::from_proto(request.room_id); let room = self .app_state .db - .cancel_call(Some(room_id), request.sender_connection_id, called_user_id) + .cancel_call(Some(room_id), session.connection_id, called_user_id) .await?; for connection_id in self.store().await.connection_ids_for_user(called_user_id) { self.peer @@ -795,41 +782,41 @@ impl Server { Ok(()) } - async fn decline_call(self: Arc, message: Message) -> Result<()> { - let room_id = RoomId::from_proto(message.payload.room_id); + async fn decline_call( + self: Arc, + message: proto::DeclineCall, + session: Session, + ) -> Result<()> { + let room_id = RoomId::from_proto(message.room_id); let 
room = self .app_state .db - .decline_call(Some(room_id), message.sender_user_id) + .decline_call(Some(room_id), session.user_id) .await?; - for connection_id in self - .store() - .await - .connection_ids_for_user(message.sender_user_id) - { + for connection_id in self.store().await.connection_ids_for_user(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); } self.room_updated(&room); - self.update_user_contacts(message.sender_user_id).await?; + self.update_user_contacts(session.user_id).await?; Ok(()) } async fn update_participant_location( self: Arc, - request: Message, + request: proto::UpdateParticipantLocation, response: Response, + session: Session, ) -> Result<()> { - let room_id = RoomId::from_proto(request.payload.room_id); + let room_id = RoomId::from_proto(request.room_id); let location = request - .payload .location .ok_or_else(|| anyhow!("invalid location"))?; let room = self .app_state .db - .update_room_participant_location(room_id, request.sender_connection_id, location) + .update_room_participant_location(room_id, session.connection_id, location) .await?; self.room_updated(&room); response.send(proto::Ack {})?; @@ -851,16 +838,17 @@ impl Server { async fn share_project( self: Arc, - request: Message, + request: proto::ShareProject, response: Response, + session: Session, ) -> Result<()> { let (project_id, room) = self .app_state .db .share_project( - RoomId::from_proto(request.payload.room_id), - request.sender_connection_id, - &request.payload.worktrees, + RoomId::from_proto(request.room_id), + session.connection_id, + &request.worktrees, ) .await?; response.send(proto::ShareProjectResponse { @@ -873,21 +861,20 @@ impl Server { async fn unshare_project( self: Arc, - message: Message, + message: proto::UnshareProject, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(message.payload.project_id); + let project_id = ProjectId::from_proto(message.project_id); let (room, 
guest_connection_ids) = self .app_state .db - .unshare_project(project_id, message.sender_connection_id) + .unshare_project(project_id, session.connection_id) .await?; - broadcast( - message.sender_connection_id, - guest_connection_ids, - |conn_id| self.peer.send(conn_id, message.payload.clone()), - ); + broadcast(session.connection_id, guest_connection_ids, |conn_id| { + self.peer.send(conn_id, message.clone()) + }); self.room_updated(&room); Ok(()) @@ -926,26 +913,25 @@ impl Server { async fn join_project( self: Arc, - request: Message, + request: proto::JoinProject, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let guest_user_id = request.sender_user_id; + let project_id = ProjectId::from_proto(request.project_id); + let guest_user_id = session.user_id; tracing::info!(%project_id, "join project"); let (project, replica_id) = self .app_state .db - .join_project(project_id, request.sender_connection_id) + .join_project(project_id, session.connection_id) .await?; let collaborators = project .collaborators .iter() - .filter(|collaborator| { - collaborator.connection_id != request.sender_connection_id.0 as i32 - }) + .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, replica_id: collaborator.replica_id.0 as u32, @@ -970,7 +956,7 @@ impl Server { proto::AddProjectCollaborator { project_id: project_id.to_proto(), collaborator: Some(proto::Collaborator { - peer_id: request.sender_connection_id.0, + peer_id: session.connection_id.0, replica_id: replica_id.0 as u32, user_id: guest_user_id.to_proto(), }), @@ -1005,14 +991,13 @@ impl Server { is_last_update: worktree.is_complete, }; for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer - .send(request.sender_connection_id, update.clone())?; + self.peer.send(session.connection_id, 
update.clone())?; } // Stream this worktree's diagnostics. for summary in worktree.diagnostic_summaries { self.peer.send( - request.sender_connection_id, + session.connection_id, proto::UpdateDiagnosticSummary { project_id: project_id.to_proto(), worktree_id: worktree.id.to_proto(), @@ -1024,7 +1009,7 @@ impl Server { for language_server in &project.language_servers { self.peer.send( - request.sender_connection_id, + session.connection_id, proto::UpdateLanguageServer { project_id: project_id.to_proto(), language_server_id: language_server.id, @@ -1040,9 +1025,13 @@ impl Server { Ok(()) } - async fn leave_project(self: Arc, request: Message) -> Result<()> { - let sender_id = request.sender_connection_id; - let project_id = ProjectId::from_proto(request.payload.project_id); + async fn leave_project( + self: Arc, + request: proto::LeaveProject, + session: Session, + ) -> Result<()> { + let sender_id = session.connection_id; + let project_id = ProjectId::from_proto(request.project_id); let project; { project = self @@ -1073,28 +1062,22 @@ impl Server { async fn update_project( self: Arc, - request: Message, + request: proto::UpdateProject, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let (room, guest_connection_ids) = self .app_state .db - .update_project( - project_id, - request.sender_connection_id, - &request.payload.worktrees, - ) + .update_project(project_id, session.connection_id, &request.worktrees) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); self.room_updated(&room); @@ -1105,24 +1088,22 @@ impl Server { async fn update_worktree( self: Arc, - request: 
Message, + request: proto::UpdateWorktree, response: Response, + session: Session, ) -> Result<()> { let guest_connection_ids = self .app_state .db - .update_worktree(&request.payload, request.sender_connection_id) + .update_worktree(&request, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); response.send(proto::Ack {})?; @@ -1131,24 +1112,22 @@ impl Server { async fn update_diagnostic_summary( self: Arc, - request: Message, + request: proto::UpdateDiagnosticSummary, response: Response, + session: Session, ) -> Result<()> { let guest_connection_ids = self .app_state .db - .update_diagnostic_summary(&request.payload, request.sender_connection_id) + .update_diagnostic_summary(&request, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); @@ -1158,23 +1137,21 @@ impl Server { async fn start_language_server( self: Arc, - request: Message, + request: proto::StartLanguageServer, + session: Session, ) -> Result<()> { let guest_connection_ids = self .app_state .db - .start_language_server(&request.payload, request.sender_connection_id) + .start_language_server(&request, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, guest_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ 
-1182,23 +1159,21 @@ impl Server { async fn update_language_server( self: Arc, - request: Message, + request: proto::UpdateLanguageServer, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1206,17 +1181,18 @@ impl Server { async fn forward_project_request( self: Arc, - request: Message, + request: T, response: Response, + session: Session, ) -> Result<()> where T: EntityMessage + RequestMessage, { - let project_id = ProjectId::from_proto(request.payload.remote_entity_id()); + let project_id = ProjectId::from_proto(request.remote_entity_id()); let collaborators = self .app_state .db - .project_collaborators(project_id, request.sender_connection_id) + .project_collaborators(project_id, session.connection_id) .await?; let host = collaborators .iter() @@ -1226,9 +1202,9 @@ impl Server { let payload = self .peer .forward_request( - request.sender_connection_id, + session.connection_id, ConnectionId(host.connection_id as u32), - request.payload, + request, ) .await?; @@ -1238,14 +1214,15 @@ impl Server { async fn save_buffer( self: Arc, - request: Message, + request: proto::SaveBuffer, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let collaborators = self .app_state .db - .project_collaborators(project_id, 
request.sender_connection_id) + .project_collaborators(project_id, session.connection_id) .await?; let host = collaborators .into_iter() @@ -1254,21 +1231,16 @@ impl Server { let host_connection_id = ConnectionId(host.connection_id as u32); let response_payload = self .peer - .forward_request( - request.sender_connection_id, - host_connection_id, - request.payload.clone(), - ) + .forward_request(session.connection_id, host_connection_id, request.clone()) .await?; let mut collaborators = self .app_state .db - .project_collaborators(project_id, request.sender_connection_id) + .project_collaborators(project_id, session.connection_id) .await?; - collaborators.retain(|collaborator| { - collaborator.connection_id != request.sender_connection_id.0 as i32 - }); + collaborators + .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); let project_connection_ids = collaborators .into_iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); @@ -1282,37 +1254,36 @@ impl Server { async fn create_buffer_for_peer( self: Arc, - request: Message, + request: proto::CreateBufferForPeer, + session: Session, ) -> Result<()> { self.peer.forward_send( - request.sender_connection_id, - ConnectionId(request.payload.peer_id), - request.payload, + session.connection_id, + ConnectionId(request.peer_id), + request, )?; Ok(()) } async fn update_buffer( self: Arc, - request: Message, + request: proto::UpdateBuffer, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - 
request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); response.send(proto::Ack {})?; @@ -1321,24 +1292,22 @@ impl Server { async fn update_buffer_file( self: Arc, - request: Message, + request: proto::UpdateBufferFile, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1346,44 +1315,43 @@ impl Server { async fn buffer_reloaded( self: Arc, - request: Message, + request: proto::BufferReloaded, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) } - async fn buffer_saved(self: Arc, request: Message) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + async fn buffer_saved( + self: Arc, + request: 
proto::BufferSaved, + session: Session, + ) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1391,16 +1359,17 @@ impl Server { async fn follow( self: Arc, - request: Message, + request: proto::Follow, response: Response, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let leader_id = ConnectionId(request.payload.leader_id); - let follower_id = request.sender_connection_id; + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let follower_id = session.connection_id; let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; if !project_connection_ids.contains(&leader_id) { @@ -1409,7 +1378,7 @@ impl Server { let mut response_payload = self .peer - .forward_request(request.sender_connection_id, leader_id, request.payload) + .forward_request(session.connection_id, leader_id, request) .await?; response_payload .views @@ -1418,50 +1387,44 @@ impl Server { Ok(()) } - async fn unfollow(self: Arc, request: Message) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); - let leader_id = ConnectionId(request.payload.leader_id); + async fn unfollow(self: Arc, request: proto::Unfollow, session: Session) -> Result<()> { + let project_id = 
ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; if !project_connection_ids.contains(&leader_id) { Err(anyhow!("no such peer"))?; } self.peer - .forward_send(request.sender_connection_id, leader_id, request.payload)?; + .forward_send(session.connection_id, leader_id, request)?; Ok(()) } async fn update_followers( self: Arc, - request: Message, + request: proto::UpdateFollowers, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; - let leader_id = request - .payload - .variant - .as_ref() - .and_then(|variant| match variant { - proto::update_followers::Variant::CreateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, - }); - for follower_id in &request.payload.follower_ids { + let leader_id = request.variant.as_ref().and_then(|variant| match variant { + proto::update_followers::Variant::CreateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, + }); + for follower_id in &request.follower_ids { let follower_id = ConnectionId(*follower_id); if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer.forward_send( - request.sender_connection_id, - follower_id, - request.payload.clone(), - )?; + self.peer + 
.forward_send(session.connection_id, follower_id, request.clone())?; } } Ok(()) @@ -1469,11 +1432,11 @@ impl Server { async fn get_users( self: Arc, - request: Message, + request: proto::GetUsers, response: Response, + _session: Session, ) -> Result<()> { let user_ids = request - .payload .user_ids .into_iter() .map(UserId::from_proto) @@ -1496,10 +1459,11 @@ impl Server { async fn fuzzy_search_users( self: Arc, - request: Message, + request: proto::FuzzySearchUsers, response: Response, + session: Session, ) -> Result<()> { - let query = request.payload.query; + let query = request.query; let db = &self.app_state.db; let users = match query.len() { 0 => vec![], @@ -1512,7 +1476,7 @@ impl Server { }; let users = users .into_iter() - .filter(|user| user.id != request.sender_user_id) + .filter(|user| user.id != session.user_id) .map(|user| proto::User { id: user.id.to_proto(), avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), @@ -1525,11 +1489,12 @@ impl Server { async fn request_contact( self: Arc, - request: Message, + request: proto::RequestContact, response: Response, + session: Session, ) -> Result<()> { - let requester_id = request.sender_user_id; - let responder_id = UserId::from_proto(request.payload.responder_id); + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.responder_id); if requester_id == responder_id { return Err(anyhow!("cannot add yourself as a contact"))?; } @@ -1564,18 +1529,19 @@ impl Server { async fn respond_to_contact_request( self: Arc, - request: Message, + request: proto::RespondToContactRequest, response: Response, + session: Session, ) -> Result<()> { - let responder_id = request.sender_user_id; - let requester_id = UserId::from_proto(request.payload.requester_id); - if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 { + let responder_id = session.user_id; + let requester_id = UserId::from_proto(request.requester_id); + if request.response == 
proto::ContactRequestResponse::Dismiss as i32 { self.app_state .db .dismiss_contact_notification(responder_id, requester_id) .await?; } else { - let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32; + let accept = request.response == proto::ContactRequestResponse::Accept as i32; self.app_state .db .respond_to_contact_request(responder_id, requester_id, accept) @@ -1618,11 +1584,12 @@ impl Server { async fn remove_contact( self: Arc, - request: Message, + request: proto::RemoveContact, response: Response, + session: Session, ) -> Result<()> { - let requester_id = request.sender_user_id; - let responder_id = UserId::from_proto(request.payload.user_id); + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.user_id); self.app_state .db .remove_contact(requester_id, responder_id) @@ -1652,23 +1619,21 @@ impl Server { async fn update_diff_base( self: Arc, - request: Message, + request: proto::UpdateDiffBase, + session: Session, ) -> Result<()> { - let project_id = ProjectId::from_proto(request.payload.project_id); + let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = self .app_state .db - .project_connection_ids(project_id, request.sender_connection_id) + .project_connection_ids(project_id, session.connection_id) .await?; broadcast( - request.sender_connection_id, + session.connection_id, project_connection_ids, |connection_id| { - self.peer.forward_send( - request.sender_connection_id, - connection_id, - request.payload.clone(), - ) + self.peer + .forward_send(session.connection_id, connection_id, request.clone()) }, ); Ok(()) @@ -1676,18 +1641,19 @@ impl Server { async fn get_private_user_info( self: Arc, - request: Message, + _request: proto::GetPrivateUserInfo, response: Response, + session: Session, ) -> Result<()> { let metrics_id = self .app_state .db - .get_user_metrics_id(request.sender_user_id) + .get_user_metrics_id(session.user_id) .await?; let user = self 
.app_state .db - .get_user_by_id(request.sender_user_id) + .get_user_by_id(session.user_id) .await? .ok_or_else(|| anyhow!("user not found"))?; response.send(proto::GetPrivateUserInfoResponse { From 0a4517f97e55ea41d6a27996a2948de669887416 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 17:30:26 +0100 Subject: [PATCH 053/109] WIP: Introduce a `db` field to `Session` Co-Authored-By: Nathan Sobo --- Cargo.lock | 6 +++--- crates/collab/Cargo.toml | 1 - crates/collab/src/rpc.rs | 21 ++++++++++++++++++--- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1cceb9f99cbbcd8d9066634fd836f12e4ecab11b..b6f86980ae5f792a9d22fb6936599b6a5ab9cf4b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -266,11 +266,12 @@ dependencies = [ [[package]] name = "async-lock" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e97a171d191782fba31bb902b14ad94e24a68145032b7eedf871ab0bc0d077b6" +checksum = "c8101efe8695a6c17e02911402145357e718ac92d3ff88ae8419e84b1707b685" dependencies = [ "event-listener", + "futures-lite", ] [[package]] @@ -1031,7 +1032,6 @@ name = "collab" version = "0.2.2" dependencies = [ "anyhow", - "async-trait", "async-tungstenite", "axum", "axum-extra", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 7456cb5598f64bd497fd2b73252ac40219e439b6..f04918605ff6a1e6e4911cbfeb01e7de045b6525 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -19,7 +19,6 @@ rpc = { path = "../rpc" } util = { path = "../util" } anyhow = "1.0.40" -async-trait = "0.1.50" async-tungstenite = "0.16" axum = { version = "0.5", features = ["json", "headers", "ws"] } axum-extra = { version = "0.3", features = ["erased-json"] } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 19d45e221d5b39e0df416ea45364a128b5a4c774..0c559239f5b74a3fed044aaa449d6d76b72804c8 100644 --- a/crates/collab/src/rpc.rs +++ 
b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod store; use crate::{ auth, - db::{self, ProjectId, RoomId, User, UserId}, + db::{self, DefaultDb, ProjectId, RoomId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -80,6 +80,17 @@ struct Response { struct Session { user_id: UserId, connection_id: ConnectionId, + db: Arc>, +} + +struct DbHandle(Arc); + +impl Deref for DbHandle { + type Target = DefaultDb; + + fn deref(&self) -> &Self::Target { + self.0.as_ref() + } } impl Response { @@ -352,6 +363,8 @@ impl Server { let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); + let db = Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))); + // Handlers for foreground messages are pushed into the following `FuturesUnordered`. // This prevents deadlocks when e.g., client A performs a request to client B and // client B performs a request to client A. If both clients stop processing further @@ -382,6 +395,7 @@ impl Server { let session = Session { user_id, connection_id, + db: db.clone(), }; let handle_message = (handler)(this.clone(), message, session); drop(span_enter); @@ -1409,9 +1423,10 @@ impl Server { session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state + let project_connection_ids = session .db + .lock() + .await .project_connection_ids(project_id, session.connection_id) .await?; From 6c83be3f89328f1e89670cec038ff6ff9b16e98c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 18:46:39 +0100 Subject: [PATCH 054/109] Remove obsolete code from `Store` --- crates/collab/src/db.rs | 15 +++ crates/collab/src/main.rs | 53 --------- crates/collab/src/rpc.rs | 60 +++++----- crates/collab/src/rpc/store.rs | 205 ++------------------------------- 4 files changed, 58 insertions(+), 275 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index c97c82c656e022596d6a9bbaf7f51f63137d5df4..6cb53738817c567a887335ad3a1f41c5c24be859 
100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1464,6 +1464,21 @@ where // projects + pub async fn project_count_excluding_admins(&self) -> Result { + self.transact(|mut tx| async move { + Ok(sqlx::query_scalar::<_, i32>( + " + SELECT COUNT(*) + FROM projects, users + WHERE projects.host_user_id = users.id AND users.admin IS FALSE + ", + ) + .fetch_one(&mut tx) + .await? as usize) + }) + .await + } + pub async fn share_project( &self, expected_room_id: RoomId, diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index dc98a2ee6855c072f5adc9ed95dbad38626eca48..20fae38c161e01fd325a05cd2868f437ccef5363 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -9,7 +9,6 @@ mod db_tests; #[cfg(test)] mod integration_tests; -use crate::rpc::ResultExt as _; use anyhow::anyhow; use axum::{routing::get, Router}; use collab::{Error, Result}; @@ -20,9 +19,7 @@ use std::{ net::{SocketAddr, TcpListener}, path::{Path, PathBuf}, sync::Arc, - time::Duration, }; -use tokio::signal; use tracing_log::LogTracer; use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer}; use util::ResultExt; @@ -129,7 +126,6 @@ async fn main() -> Result<()> { axum::Server::from_tcp(listener)? .serve(app.into_make_service_with_connect_info::()) - .with_graceful_shutdown(graceful_shutdown(rpc_server, state)) .await?; } _ => { @@ -174,52 +170,3 @@ pub fn init_tracing(config: &Config) -> Option<()> { None } - -async fn graceful_shutdown(rpc_server: Arc, state: Arc) { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); - - tokio::select! 
{ - _ = ctrl_c => {}, - _ = terminate => {}, - } - - if let Some(live_kit) = state.live_kit_client.as_ref() { - let deletions = rpc_server - .store() - .await - .rooms() - .values() - .map(|room| { - let name = room.live_kit_room.clone(); - async { - live_kit.delete_room(name).await.trace_err(); - } - }) - .collect::>(); - - tracing::info!("deleting all live-kit rooms"); - if let Err(_) = tokio::time::timeout( - Duration::from_secs(10), - futures::future::join_all(deletions), - ) - .await - { - tracing::error!("timed out waiting for live-kit room deletion"); - } - } -} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 0c559239f5b74a3fed044aaa449d6d76b72804c8..58870163f50f349082636e9753171bc80560ea7f 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -49,7 +49,7 @@ use std::{ }, time::Duration, }; -pub use store::{Store, Worktree}; +pub use store::Store; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -437,7 +437,7 @@ impl Server { let decline_calls = { let mut store = self.store().await; store.remove_connection(connection_id)?; - let mut connections = store.connection_ids_for_user(user_id); + let mut connections = store.user_connection_ids(user_id); connections.next().is_none() }; @@ -470,7 +470,7 @@ impl Server { if let Some(code) = &user.invite_code { let store = self.store().await; let invitee_contact = store.contact_for_user(invitee_id, true, false); - for connection_id in store.connection_ids_for_user(inviter_id) { + for connection_id in store.user_connection_ids(inviter_id) { self.peer.send( connection_id, proto::UpdateContacts { @@ -495,7 +495,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? 
{ if let Some(invite_code) = &user.invite_code { let store = self.store().await; - for connection_id in store.connection_ids_for_user(user_id) { + for connection_id in store.user_connection_ids(user_id) { self.peer.send( connection_id, proto::UpdateInviteInfo { @@ -582,7 +582,7 @@ impl Server { session.connection_id, ) .await?; - for connection_id in self.store().await.connection_ids_for_user(session.user_id) { + for connection_id in self.store().await.user_connection_ids(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -674,7 +674,7 @@ impl Server { { let store = self.store().await; for canceled_user_id in left_room.canceled_calls_to_user_ids { - for connection_id in store.connection_ids_for_user(canceled_user_id) { + for connection_id in store.user_connection_ids(canceled_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -744,7 +744,7 @@ impl Server { let mut calls = self .store() .await - .connection_ids_for_user(called_user_id) + .user_connection_ids(called_user_id) .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) .collect::>(); @@ -784,7 +784,7 @@ impl Server { .db .cancel_call(Some(room_id), session.connection_id, called_user_id) .await?; - for connection_id in self.store().await.connection_ids_for_user(called_user_id) { + for connection_id in self.store().await.user_connection_ids(called_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -807,7 +807,7 @@ impl Server { .db .decline_call(Some(room_id), session.user_id) .await?; - for connection_id in self.store().await.connection_ids_for_user(session.user_id) { + for connection_id in self.store().await.user_connection_ids(session.user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -905,7 +905,7 @@ impl Server { .. 
} = contact { - for contact_conn_id in store.connection_ids_for_user(contact_user_id) { + for contact_conn_id in store.user_connection_ids(contact_user_id) { self.peer .send( contact_conn_id, @@ -1522,7 +1522,7 @@ impl Server { // Update outgoing contact requests of requester let mut update = proto::UpdateContacts::default(); update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(requester_id) { + for connection_id in self.store().await.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } @@ -1534,7 +1534,7 @@ impl Server { requester_id: requester_id.to_proto(), should_notify: true, }); - for connection_id in self.store().await.connection_ids_for_user(responder_id) { + for connection_id in self.store().await.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1574,7 +1574,7 @@ impl Server { update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in store.connection_ids_for_user(responder_id) { + for connection_id in store.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1588,7 +1588,7 @@ impl Server { update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in store.connection_ids_for_user(requester_id) { + for connection_id in store.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } } @@ -1615,7 +1615,7 @@ impl Server { update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(requester_id) { + for connection_id in self.store().await.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } @@ -1624,7 +1624,7 @@ impl Server { update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in self.store().await.connection_ids_for_user(responder_id) { + for connection_id in 
self.store().await.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1819,21 +1819,25 @@ pub async fn handle_websocket_request( }) } -pub async fn handle_metrics(Extension(server): Extension>) -> axum::response::Response { - let metrics = server.store().await.metrics(); - METRIC_CONNECTIONS.set(metrics.connections as _); - METRIC_SHARED_PROJECTS.set(metrics.shared_projects as _); +pub async fn handle_metrics(Extension(server): Extension>) -> Result { + let connections = server + .store() + .await + .connections() + .filter(|connection| !connection.admin) + .count(); + + METRIC_CONNECTIONS.set(connections as _); + + let shared_projects = server.app_state.db.project_count_excluding_admins().await?; + METRIC_SHARED_PROJECTS.set(shared_projects as _); let encoder = prometheus::TextEncoder::new(); let metric_families = prometheus::gather(); - match encoder.encode_to_string(&metric_families) { - Ok(string) => (StatusCode::OK, string).into_response(), - Err(error) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("failed to encode metrics {:?}", error), - ) - .into_response(), - } + let encoded_metrics = encoder + .encode_to_string(&metric_families) + .map_err(|err| anyhow!("{}", err))?; + Ok(encoded_metrics) } fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 1aa9c709b733dcc317f7274854603b36a8c6bf51..2bb6d89f401a0274c3ac83b70eaa9cd192c882d1 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -1,111 +1,32 @@ -use crate::db::{self, ProjectId, UserId}; +use crate::db::{self, UserId}; use anyhow::{anyhow, Result}; -use collections::{BTreeMap, BTreeSet, HashMap, HashSet}; +use collections::{BTreeMap, HashSet}; use rpc::{proto, ConnectionId}; use serde::Serialize; -use std::path::PathBuf; use tracing::instrument; -pub type RoomId = u64; - #[derive(Default, Serialize)] pub struct Store { - connections: 
BTreeMap, + connections: BTreeMap, connected_users: BTreeMap, - next_room_id: RoomId, - rooms: BTreeMap, - projects: BTreeMap, } #[derive(Default, Serialize)] struct ConnectedUser { connection_ids: HashSet, - active_call: Option, } #[derive(Serialize)] -struct ConnectionState { - user_id: UserId, - admin: bool, - projects: BTreeSet, -} - -#[derive(Copy, Clone, Eq, PartialEq, Serialize)] -pub struct Call { - pub calling_user_id: UserId, - pub room_id: RoomId, - pub connection_id: Option, - pub initial_project_id: Option, -} - -#[derive(Serialize)] -pub struct Project { - pub id: ProjectId, - pub room_id: RoomId, - pub host_connection_id: ConnectionId, - pub host: Collaborator, - pub guests: HashMap, - pub active_replica_ids: HashSet, - pub worktrees: BTreeMap, - pub language_servers: Vec, -} - -#[derive(Serialize)] -pub struct Collaborator { - pub replica_id: ReplicaId, +pub struct Connection { pub user_id: UserId, pub admin: bool, } -#[derive(Default, Serialize)] -pub struct Worktree { - pub abs_path: PathBuf, - pub root_name: String, - pub visible: bool, - #[serde(skip)] - pub entries: BTreeMap, - #[serde(skip)] - pub diagnostic_summaries: BTreeMap, - pub scan_id: u64, - pub is_complete: bool, -} - -pub type ReplicaId = u16; - -#[derive(Copy, Clone)] -pub struct Metrics { - pub connections: usize, - pub shared_projects: usize, -} - impl Store { - pub fn metrics(&self) -> Metrics { - let connections = self.connections.values().filter(|c| !c.admin).count(); - let mut shared_projects = 0; - for project in self.projects.values() { - if let Some(connection) = self.connections.get(&project.host_connection_id) { - if !connection.admin { - shared_projects += 1; - } - } - } - - Metrics { - connections, - shared_projects, - } - } - #[instrument(skip(self))] pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { - self.connections.insert( - connection_id, - ConnectionState { - user_id, - admin, - projects: Default::default(), - }, - ); 
+ self.connections + .insert(connection_id, Connection { user_id, admin }); let connected_user = self.connected_users.entry(user_id).or_default(); connected_user.connection_ids.insert(connection_id); } @@ -127,10 +48,11 @@ impl Store { Ok(()) } - pub fn connection_ids_for_user( - &self, - user_id: UserId, - ) -> impl Iterator + '_ { + pub fn connections(&self) -> impl Iterator { + self.connections.values() + } + + pub fn user_connection_ids(&self, user_id: UserId) -> impl Iterator + '_ { self.connected_users .get(&user_id) .into_iter() @@ -197,35 +119,9 @@ impl Store { } } - pub fn rooms(&self) -> &BTreeMap { - &self.rooms - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { - for project_id in &connection.projects { - let project = &self.projects.get(project_id).unwrap(); - if project.host_connection_id != *connection_id { - assert!(project.guests.contains_key(connection_id)); - } - - for (worktree_id, worktree) in project.worktrees.iter() { - let mut paths = HashMap::default(); - for entry in worktree.entries.values() { - let prev_entry = paths.insert(&entry.path, entry); - assert_eq!( - prev_entry, - None, - "worktree {:?}, duplicate path for entries {:?} and {:?}", - worktree_id, - prev_entry.unwrap(), - entry - ); - } - } - } - assert!(self .connected_users .get(&connection.user_id) @@ -241,85 +137,6 @@ impl Store { *user_id ); } - - if let Some(active_call) = state.active_call.as_ref() { - if let Some(active_call_connection_id) = active_call.connection_id { - assert!( - state.connection_ids.contains(&active_call_connection_id), - "call is active on a dead connection" - ); - assert!( - state.connection_ids.contains(&active_call_connection_id), - "call is active on a dead connection" - ); - } - } - } - - for (room_id, room) in &self.rooms { - // for pending_user_id in &room.pending_participant_user_ids { - // assert!( - // self.connected_users - // .contains_key(&UserId::from_proto(*pending_user_id)), - // 
"call is active on a user that has disconnected" - // ); - // } - - for participant in &room.participants { - assert!( - self.connections - .contains_key(&ConnectionId(participant.peer_id)), - "room {} contains participant {:?} that has disconnected", - room_id, - participant - ); - - for participant_project in &participant.projects { - let project = &self.projects[&ProjectId::from_proto(participant_project.id)]; - assert_eq!( - project.room_id, *room_id, - "project was shared on a different room" - ); - } - } - - // assert!( - // !room.pending_participant_user_ids.is_empty() || !room.participants.is_empty(), - // "room can't be empty" - // ); - } - - for (project_id, project) in &self.projects { - let host_connection = self.connections.get(&project.host_connection_id).unwrap(); - assert!(host_connection.projects.contains(project_id)); - - for guest_connection_id in project.guests.keys() { - let guest_connection = self.connections.get(guest_connection_id).unwrap(); - assert!(guest_connection.projects.contains(project_id)); - } - assert_eq!(project.active_replica_ids.len(), project.guests.len()); - assert_eq!( - project.active_replica_ids, - project - .guests - .values() - .map(|guest| guest.replica_id) - .collect::>(), - ); - - let room = &self.rooms[&project.room_id]; - let room_participant = room - .participants - .iter() - .find(|participant| participant.peer_id == project.host_connection_id.0) - .unwrap(); - assert!( - room_participant - .projects - .iter() - .any(|project| project.id == project_id.to_proto()), - "project was not shared in room" - ); } } } From 44bb2ce024a2b9afe747023f6a6a01068eccef67 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 17 Nov 2022 19:03:50 +0100 Subject: [PATCH 055/109] Rename `Store` to `ConnectionPool` --- crates/collab/src/integration_tests.rs | 21 +-- crates/collab/src/rpc.rs | 167 +++++++++++++----- .../src/rpc/{store.rs => connection_pool.rs} | 57 +----- 3 files changed, 133 insertions(+), 112 deletions(-) rename 
crates/collab/src/rpc/{store.rs => connection_pool.rs} (64%) diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 1236af42cb05af4b544f74166284d34aa3e44739..006598a6b191e593c7934d145a3c146da0a7c496 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{NewUserParams, SqliteTestDb as TestDb, UserId}, + db::{self, NewUserParams, SqliteTestDb as TestDb, UserId}, rpc::{Executor, Server}, AppState, }; @@ -5469,18 +5469,15 @@ async fn test_random_collaboration( } for user_id in &user_ids { let contacts = server.app_state.db.get_contacts(*user_id).await.unwrap(); - let contacts = server - .store - .lock() - .await - .build_initial_contacts_update(contacts) - .contacts; + let pool = server.connection_pool.lock().await; for contact in contacts { - if contact.online { - assert_ne!( - contact.user_id, removed_guest_id.0 as u64, - "removed guest is still a contact of another peer" - ); + if let db::Contact::Accepted { user_id, .. 
} = contact { + if pool.is_user_online(user_id) { + assert_ne!( + user_id, removed_guest_id, + "removed guest is still a contact of another peer" + ); + } } } } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 58870163f50f349082636e9753171bc80560ea7f..175e3604c04acc522348a6f2c92e7fdb53b16599 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1,4 +1,4 @@ -mod store; +mod connection_pool; use crate::{ auth, @@ -23,6 +23,7 @@ use axum::{ Extension, Router, TypedHeader, }; use collections::{HashMap, HashSet}; +pub use connection_pool::ConnectionPool; use futures::{ channel::oneshot, future::{self, BoxFuture}, @@ -49,7 +50,6 @@ use std::{ }, time::Duration, }; -pub use store::Store; use tokio::{ sync::{Mutex, MutexGuard}, time::Sleep, @@ -103,7 +103,7 @@ impl Response { pub struct Server { peer: Arc, - pub(crate) store: Mutex, + pub(crate) connection_pool: Mutex, app_state: Arc, handlers: HashMap, } @@ -117,8 +117,8 @@ pub trait Executor: Send + Clone { #[derive(Clone)] pub struct RealExecutor; -pub(crate) struct StoreGuard<'a> { - guard: MutexGuard<'a, Store>, +pub(crate) struct ConnectionPoolGuard<'a> { + guard: MutexGuard<'a, ConnectionPool>, _not_send: PhantomData>, } @@ -126,7 +126,7 @@ pub(crate) struct StoreGuard<'a> { pub struct ServerSnapshot<'a> { peer: &'a Peer, #[serde(serialize_with = "serialize_deref")] - store: StoreGuard<'a>, + connection_pool: ConnectionPoolGuard<'a>, } pub fn serialize_deref(value: &T, serializer: S) -> Result @@ -143,7 +143,7 @@ impl Server { let mut server = Self { peer: Peer::new(), app_state, - store: Default::default(), + connection_pool: Default::default(), handlers: Default::default(), }; @@ -257,8 +257,6 @@ impl Server { self } - /// Handle a request while holding a lock to the store. This is useful when we're registering - /// a connection but we want to respond on the connection before anybody else can send on it. 
fn add_request_handler(&mut self, handler: F) -> &mut Self where F: 'static + Send + Sync + Fn(Arc, M, Response, Session) -> Fut, @@ -342,9 +340,9 @@ impl Server { ).await?; { - let mut store = this.store().await; - store.add_connection(connection_id, user_id, user.admin); - this.peer.send(connection_id, store.build_initial_contacts_update(contacts))?; + let mut pool = this.connection_pool().await; + pool.add_connection(connection_id, user_id, user.admin); + this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; if let Some((code, count)) = invite_code { this.peer.send(connection_id, proto::UpdateInviteInfo { @@ -435,9 +433,9 @@ impl Server { ) -> Result<()> { self.peer.disconnect(connection_id); let decline_calls = { - let mut store = self.store().await; - store.remove_connection(connection_id)?; - let mut connections = store.user_connection_ids(user_id); + let mut pool = self.connection_pool().await; + pool.remove_connection(connection_id)?; + let mut connections = pool.user_connection_ids(user_id); connections.next().is_none() }; @@ -468,9 +466,9 @@ impl Server { ) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { - let store = self.store().await; - let invitee_contact = store.contact_for_user(invitee_id, true, false); - for connection_id in store.user_connection_ids(inviter_id) { + let pool = self.connection_pool().await; + let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, proto::UpdateContacts { @@ -494,8 +492,8 @@ impl Server { pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? 
{ if let Some(invite_code) = &user.invite_code { - let store = self.store().await; - for connection_id in store.user_connection_ids(user_id) { + let pool = self.connection_pool().await; + for connection_id in pool.user_connection_ids(user_id) { self.peer.send( connection_id, proto::UpdateInviteInfo { @@ -582,7 +580,11 @@ impl Server { session.connection_id, ) .await?; - for connection_id in self.store().await.user_connection_ids(session.user_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(session.user_id) + { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -672,9 +674,9 @@ impl Server { self.room_updated(&left_room.room); { - let store = self.store().await; + let pool = self.connection_pool().await; for canceled_user_id in left_room.canceled_calls_to_user_ids { - for connection_id in store.user_connection_ids(canceled_user_id) { + for connection_id in pool.user_connection_ids(canceled_user_id) { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -742,7 +744,7 @@ impl Server { self.update_user_contacts(called_user_id).await?; let mut calls = self - .store() + .connection_pool() .await .user_connection_ids(called_user_id) .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) @@ -784,7 +786,11 @@ impl Server { .db .cancel_call(Some(room_id), session.connection_id, called_user_id) .await?; - for connection_id in self.store().await.user_connection_ids(called_user_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(called_user_id) + { self.peer .send(connection_id, proto::CallCanceled {}) .trace_err(); @@ -807,7 +813,11 @@ impl Server { .db .decline_call(Some(room_id), session.user_id) .await?; - for connection_id in self.store().await.user_connection_ids(session.user_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(session.user_id) + { self.peer .send(connection_id, proto::CallCanceled {}) 
.trace_err(); @@ -897,15 +907,15 @@ impl Server { async fn update_user_contacts(self: &Arc, user_id: UserId) -> Result<()> { let contacts = self.app_state.db.get_contacts(user_id).await?; let busy = self.app_state.db.is_user_busy(user_id).await?; - let store = self.store().await; - let updated_contact = store.contact_for_user(user_id, false, busy); + let pool = self.connection_pool().await; + let updated_contact = contact_for_user(user_id, false, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, .. } = contact { - for contact_conn_id in store.user_connection_ids(contact_user_id) { + for contact_conn_id in pool.user_connection_ids(contact_user_id) { self.peer .send( contact_conn_id, @@ -1522,7 +1532,11 @@ impl Server { // Update outgoing contact requests of requester let mut update = proto::UpdateContacts::default(); update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self.store().await.user_connection_ids(requester_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(requester_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1534,7 +1548,11 @@ impl Server { requester_id: requester_id.to_proto(), should_notify: true, }); - for connection_id in self.store().await.user_connection_ids(responder_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(responder_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1563,18 +1581,18 @@ impl Server { .await?; let busy = self.app_state.db.is_user_busy(requester_id).await?; - let store = self.store().await; + let pool = self.connection_pool().await; // Update responder with new contact let mut update = proto::UpdateContacts::default(); if accept { update .contacts - .push(store.contact_for_user(requester_id, false, busy)); + .push(contact_for_user(requester_id, false, busy, &pool)); } update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id 
in store.user_connection_ids(responder_id) { + for connection_id in pool.user_connection_ids(responder_id) { self.peer.send(connection_id, update.clone())?; } @@ -1583,12 +1601,12 @@ impl Server { if accept { update .contacts - .push(store.contact_for_user(responder_id, true, busy)); + .push(contact_for_user(responder_id, true, busy, &pool)); } update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in store.user_connection_ids(requester_id) { + for connection_id in pool.user_connection_ids(requester_id) { self.peer.send(connection_id, update.clone())?; } } @@ -1615,7 +1633,11 @@ impl Server { update .remove_outgoing_requests .push(responder_id.to_proto()); - for connection_id in self.store().await.user_connection_ids(requester_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(requester_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1624,7 +1646,11 @@ impl Server { update .remove_incoming_requests .push(requester_id.to_proto()); - for connection_id in self.store().await.user_connection_ids(responder_id) { + for connection_id in self + .connection_pool() + .await + .user_connection_ids(responder_id) + { self.peer.send(connection_id, update.clone())?; } @@ -1678,13 +1704,13 @@ impl Server { Ok(()) } - pub(crate) async fn store(&self) -> StoreGuard<'_> { + pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { #[cfg(test)] tokio::task::yield_now().await; - let guard = self.store.lock().await; + let guard = self.connection_pool.lock().await; #[cfg(test)] tokio::task::yield_now().await; - StoreGuard { + ConnectionPoolGuard { guard, _not_send: PhantomData, } @@ -1692,27 +1718,27 @@ impl Server { pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { ServerSnapshot { - store: self.store().await, + connection_pool: self.connection_pool().await, peer: &self.peer, } } } -impl<'a> Deref for StoreGuard<'a> { - type Target = Store; +impl<'a> Deref for 
ConnectionPoolGuard<'a> { + type Target = ConnectionPool; fn deref(&self) -> &Self::Target { &*self.guard } } -impl<'a> DerefMut for StoreGuard<'a> { +impl<'a> DerefMut for ConnectionPoolGuard<'a> { fn deref_mut(&mut self) -> &mut Self::Target { &mut *self.guard } } -impl<'a> Drop for StoreGuard<'a> { +impl<'a> Drop for ConnectionPoolGuard<'a> { fn drop(&mut self) { #[cfg(test)] self.check_invariants(); @@ -1821,7 +1847,7 @@ pub async fn handle_websocket_request( pub async fn handle_metrics(Extension(server): Extension>) -> Result { let connections = server - .store() + .connection_pool() .await .connections() .filter(|connection| !connection.admin) @@ -1868,6 +1894,53 @@ fn to_tungstenite_message(message: AxumMessage) -> TungsteniteMessage { } } +fn build_initial_contacts_update( + contacts: Vec, + pool: &ConnectionPool, +) -> proto::UpdateContacts { + let mut update = proto::UpdateContacts::default(); + + for contact in contacts { + match contact { + db::Contact::Accepted { + user_id, + should_notify, + busy, + } => { + update + .contacts + .push(contact_for_user(user_id, should_notify, busy, &pool)); + } + db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), + db::Contact::Incoming { + user_id, + should_notify, + } => update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + should_notify, + }), + } + } + + update +} + +fn contact_for_user( + user_id: UserId, + should_notify: bool, + busy: bool, + pool: &ConnectionPool, +) -> proto::Contact { + proto::Contact { + user_id: user_id.to_proto(), + online: pool.is_user_online(user_id), + busy, + should_notify, + } +} + pub trait ResultExt { type Ok; diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/connection_pool.rs similarity index 64% rename from crates/collab/src/rpc/store.rs rename to crates/collab/src/rpc/connection_pool.rs index 2bb6d89f401a0274c3ac83b70eaa9cd192c882d1..ac7632f7da2ae6d4d6beb95aeb298d8e409f8d80 
100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/connection_pool.rs @@ -1,12 +1,12 @@ -use crate::db::{self, UserId}; +use crate::db::UserId; use anyhow::{anyhow, Result}; use collections::{BTreeMap, HashSet}; -use rpc::{proto, ConnectionId}; +use rpc::ConnectionId; use serde::Serialize; use tracing::instrument; #[derive(Default, Serialize)] -pub struct Store { +pub struct ConnectionPool { connections: BTreeMap, connected_users: BTreeMap, } @@ -22,7 +22,7 @@ pub struct Connection { pub admin: bool, } -impl Store { +impl ConnectionPool { #[instrument(skip(self))] pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId, admin: bool) { self.connections @@ -70,55 +70,6 @@ impl Store { .is_empty() } - pub fn build_initial_contacts_update( - &self, - contacts: Vec, - ) -> proto::UpdateContacts { - let mut update = proto::UpdateContacts::default(); - - for contact in contacts { - match contact { - db::Contact::Accepted { - user_id, - should_notify, - busy, - } => { - update - .contacts - .push(self.contact_for_user(user_id, should_notify, busy)); - } - db::Contact::Outgoing { user_id } => { - update.outgoing_requests.push(user_id.to_proto()) - } - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), - } - } - - update - } - - pub fn contact_for_user( - &self, - user_id: UserId, - should_notify: bool, - busy: bool, - ) -> proto::Contact { - proto::Contact { - user_id: user_id.to_proto(), - online: self.is_user_online(user_id), - busy, - should_notify, - } - } - #[cfg(test)] pub fn check_invariants(&self) { for (connection_id, connection) in &self.connections { From c3d556d9bdf6a924e07b945c06f882bed93cfbce Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 18 Nov 2022 11:45:42 +0100 Subject: [PATCH 056/109] Don't take an `Arc` in message handlers --- crates/collab/src/rpc.rs | 2675 
+++++++++++++++++++------------------- 1 file changed, 1320 insertions(+), 1355 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 175e3604c04acc522348a6f2c92e7fdb53b16599..ba97b09acd1a72b0fb7340c9a4ace2e8b62cffca 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -39,6 +39,7 @@ use rpc::{ use serde::{Serialize, Serializer}; use std::{ any::TypeId, + fmt, future::Future, marker::PhantomData, net::SocketAddr, @@ -67,20 +68,63 @@ lazy_static! { .unwrap(); } -type MessageHandler = Box< - dyn Send + Sync + Fn(Arc, Box, Session) -> BoxFuture<'static, ()>, ->; +type MessageHandler = + Box, Session) -> BoxFuture<'static, ()>>; struct Response { - server: Arc, + peer: Arc, receipt: Receipt, responded: Arc, } +impl Response { + fn send(self, payload: R::Response) -> Result<()> { + self.responded.store(true, SeqCst); + self.peer.respond(self.receipt, payload)?; + Ok(()) + } +} + +#[derive(Clone)] struct Session { user_id: UserId, connection_id: ConnectionId, db: Arc>, + peer: Arc, + connection_pool: Arc>, + live_kit_client: Option>, +} + +impl Session { + async fn db(&self) -> MutexGuard { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.db.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + guard + } + + async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { + #[cfg(test)] + tokio::task::yield_now().await; + let guard = self.connection_pool.lock().await; + #[cfg(test)] + tokio::task::yield_now().await; + ConnectionPoolGuard { + guard, + _not_send: PhantomData, + } + } +} + +impl fmt::Debug for Session { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Session") + .field("user_id", &self.user_id) + .field("connection_id", &self.connection_id) + .finish() + } } struct DbHandle(Arc); @@ -93,17 +137,9 @@ impl Deref for DbHandle { } } -impl Response { - fn send(self, payload: R::Response) -> Result<()> { - self.responded.store(true, SeqCst); - 
self.server.peer.respond(self.receipt, payload)?; - Ok(()) - } -} - pub struct Server { peer: Arc, - pub(crate) connection_pool: Mutex, + pub(crate) connection_pool: Arc>, app_state: Arc, handlers: HashMap, } @@ -148,76 +184,74 @@ impl Server { }; server - .add_request_handler(Server::ping) - .add_request_handler(Server::create_room) - .add_request_handler(Server::join_room) - .add_message_handler(Server::leave_room) - .add_request_handler(Server::call) - .add_request_handler(Server::cancel_call) - .add_message_handler(Server::decline_call) - .add_request_handler(Server::update_participant_location) - .add_request_handler(Server::share_project) - .add_message_handler(Server::unshare_project) - .add_request_handler(Server::join_project) - .add_message_handler(Server::leave_project) - .add_request_handler(Server::update_project) - .add_request_handler(Server::update_worktree) - .add_message_handler(Server::start_language_server) - .add_message_handler(Server::update_language_server) - .add_request_handler(Server::update_diagnostic_summary) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler( - Server::forward_project_request::, - ) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - 
.add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_request_handler(Server::forward_project_request::) - .add_message_handler(Server::create_buffer_for_peer) - .add_request_handler(Server::update_buffer) - .add_message_handler(Server::update_buffer_file) - .add_message_handler(Server::buffer_reloaded) - .add_message_handler(Server::buffer_saved) - .add_request_handler(Server::save_buffer) - .add_request_handler(Server::get_users) - .add_request_handler(Server::fuzzy_search_users) - .add_request_handler(Server::request_contact) - .add_request_handler(Server::remove_contact) - .add_request_handler(Server::respond_to_contact_request) - .add_request_handler(Server::follow) - .add_message_handler(Server::unfollow) - .add_message_handler(Server::update_followers) - .add_message_handler(Server::update_diff_base) - .add_request_handler(Server::get_private_user_info); + .add_request_handler(ping) + .add_request_handler(create_room) + .add_request_handler(join_room) + .add_message_handler(leave_room) + .add_request_handler(call) + .add_request_handler(cancel_call) + .add_message_handler(decline_call) + .add_request_handler(update_participant_location) + .add_request_handler(share_project) + .add_message_handler(unshare_project) + .add_request_handler(join_project) + .add_message_handler(leave_project) + .add_request_handler(update_project) + .add_request_handler(update_worktree) + .add_message_handler(start_language_server) + .add_message_handler(update_language_server) + .add_request_handler(update_diagnostic_summary) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + 
.add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_request_handler(forward_project_request::) + .add_message_handler(create_buffer_for_peer) + .add_request_handler(update_buffer) + .add_message_handler(update_buffer_file) + .add_message_handler(buffer_reloaded) + .add_message_handler(buffer_saved) + .add_request_handler(save_buffer) + .add_request_handler(get_users) + .add_request_handler(fuzzy_search_users) + .add_request_handler(request_contact) + .add_request_handler(remove_contact) + .add_request_handler(respond_to_contact_request) + .add_request_handler(follow) + .add_message_handler(unfollow) + .add_message_handler(update_followers) + .add_message_handler(update_diff_base) + .add_request_handler(get_private_user_info); Arc::new(server) } fn add_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, TypedEnvelope, Session) -> Fut, + F: 'static + Send + Sync + Fn(TypedEnvelope, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { let prev_handler = self.handlers.insert( TypeId::of::(), - Box::new(move |server, envelope, session| { + Box::new(move |envelope, session| { let envelope = 
envelope.into_any().downcast::>().unwrap(); let span = info_span!( "handle message", @@ -229,7 +263,7 @@ impl Server { "message received" ); }); - let future = (handler)(server, *envelope, session); + let future = (handler)(*envelope, session); async move { if let Err(error) = future.await { tracing::error!(%error, "error handling message"); @@ -247,34 +281,33 @@ impl Server { fn add_message_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, M, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Session) -> Fut, Fut: 'static + Send + Future>, M: EnvelopedMessage, { - self.add_handler(move |server, envelope, session| { - handler(server, envelope.payload, session) - }); + self.add_handler(move |envelope, session| handler(envelope.payload, session)); self } fn add_request_handler(&mut self, handler: F) -> &mut Self where - F: 'static + Send + Sync + Fn(Arc, M, Response, Session) -> Fut, + F: 'static + Send + Sync + Fn(M, Response, Session) -> Fut, Fut: Send + Future>, M: RequestMessage, { let handler = Arc::new(handler); - self.add_handler(move |server, envelope, session| { + self.add_handler(move |envelope, session| { let receipt = envelope.receipt(); let handler = handler.clone(); async move { + let peer = session.peer.clone(); let responded = Arc::new(AtomicBool::default()); let response = Response { - server: server.clone(), + peer: peer.clone(), responded: responded.clone(), receipt, }; - match (handler)(server.clone(), envelope.payload, response, session).await { + match (handler)(envelope.payload, response, session).await { Ok(()) => { if responded.load(std::sync::atomic::Ordering::SeqCst) { Ok(()) @@ -283,7 +316,7 @@ impl Server { } } Err(error) => { - server.peer.respond_with_error( + peer.respond_with_error( receipt, proto::Error { message: error.to_string(), @@ -304,7 +337,7 @@ impl Server { mut send_connection_id: Option>, executor: E, ) -> impl Future> { - let mut this = self.clone(); + let this = self.clone(); let 
user_id = user.id; let login = user.github_login; let span = info_span!("handle connection", %user_id, %login, %address); @@ -340,7 +373,7 @@ impl Server { ).await?; { - let mut pool = this.connection_pool().await; + let mut pool = this.connection_pool.lock().await; pool.add_connection(connection_id, user_id, user.admin); this.peer.send(connection_id, build_initial_contacts_update(contacts, &pool))?; @@ -356,13 +389,19 @@ impl Server { this.peer.send(connection_id, incoming_call)?; } - this.update_user_contacts(user_id).await?; + let session = Session { + user_id, + connection_id, + db: Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))), + peer: this.peer.clone(), + connection_pool: this.connection_pool.clone(), + live_kit_client: this.app_state.live_kit_client.clone() + }; + update_user_contacts(user_id, &session).await?; let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); - let db = Arc::new(Mutex::new(DbHandle(this.app_state.db.clone()))); - // Handlers for foreground messages are pushed into the following `FuturesUnordered`. // This prevents deadlocks when e.g., client A performs a request to client B and // client B performs a request to client A. 
If both clients stop processing further @@ -390,12 +429,7 @@ impl Server { let span_enter = span.enter(); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { let is_background = message.is_background(); - let session = Session { - user_id, - connection_id, - db: db.clone(), - }; - let handle_message = (handler)(this.clone(), message, session); + let handle_message = (handler)(message, session.clone()); drop(span_enter); let handle_message = handle_message.instrument(span); @@ -417,7 +451,7 @@ impl Server { drop(foreground_message_handlers); tracing::info!(%user_id, %login, %connection_id, %address, "signing out"); - if let Err(error) = this.sign_out(connection_id, user_id).await { + if let Err(error) = sign_out(session).await { tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out"); } @@ -425,40 +459,6 @@ impl Server { }.instrument(span) } - #[instrument(skip(self), err)] - async fn sign_out( - self: &mut Arc, - connection_id: ConnectionId, - user_id: UserId, - ) -> Result<()> { - self.peer.disconnect(connection_id); - let decline_calls = { - let mut pool = self.connection_pool().await; - pool.remove_connection(connection_id)?; - let mut connections = pool.user_connection_ids(user_id); - connections.next().is_none() - }; - - self.leave_room_for_connection(connection_id, user_id) - .await - .trace_err(); - if decline_calls { - if let Some(room) = self - .app_state - .db - .decline_call(None, user_id) - .await - .trace_err() - { - self.room_updated(&room); - } - } - - self.update_user_contacts(user_id).await?; - - Ok(()) - } - pub async fn invite_code_redeemed( self: &Arc, inviter_id: UserId, @@ -466,7 +466,7 @@ impl Server { ) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? 
{ if let Some(code) = &user.invite_code { - let pool = self.connection_pool().await; + let pool = self.connection_pool.lock().await; let invitee_contact = contact_for_user(invitee_id, true, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( @@ -492,7 +492,7 @@ impl Server { pub async fn invite_count_updated(self: &Arc, user_id: UserId) -> Result<()> { if let Some(user) = self.app_state.db.get_user_by_id(user_id).await? { if let Some(invite_code) = &user.invite_code { - let pool = self.connection_pool().await; + let pool = self.connection_pool.lock().await; for connection_id in pool.user_connection_ids(user_id) { self.peer.send( connection_id, @@ -510,1360 +510,1194 @@ impl Server { Ok(()) } - async fn ping( - self: Arc, - _: proto::Ping, - response: Response, - _session: Session, - ) -> Result<()> { - response.send(proto::Ack {})?; - Ok(()) + pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { + ServerSnapshot { + connection_pool: ConnectionPoolGuard { + guard: self.connection_pool.lock().await, + _not_send: PhantomData, + }, + peer: &self.peer, + } } +} - async fn create_room( - self: Arc, - _request: proto::CreateRoom, - response: Response, - session: Session, - ) -> Result<()> { - let room = self - .app_state - .db - .create_room(session.user_id, session.connection_id) - .await?; - - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(_) = live_kit - .create_room(room.live_kit_room.clone()) - .await - .trace_err() - { - if let Some(token) = live_kit - .room_token(&room.live_kit_room, &session.connection_id.to_string()) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } - } else { - None - } - } else { - None - }; +impl<'a> Deref for ConnectionPoolGuard<'a> { + type Target = ConnectionPool; - response.send(proto::CreateRoomResponse { - room: Some(room), - 
live_kit_connection_info, - })?; - self.update_user_contacts(session.user_id).await?; - Ok(()) + fn deref(&self) -> &Self::Target { + &*self.guard } +} - async fn join_room( - self: Arc, - request: proto::JoinRoom, - response: Response, - session: Session, - ) -> Result<()> { - let room = self - .app_state - .db - .join_room( - RoomId::from_proto(request.id), - session.user_id, - session.connection_id, - ) - .await?; - for connection_id in self - .connection_pool() - .await - .user_connection_ids(session.user_id) - { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } +impl<'a> DerefMut for ConnectionPoolGuard<'a> { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut *self.guard + } +} - let live_kit_connection_info = - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - if let Some(token) = live_kit - .room_token(&room.live_kit_room, &session.connection_id.to_string()) - .trace_err() - { - Some(proto::LiveKitConnectionInfo { - server_url: live_kit.url().into(), - token, - }) - } else { - None - } - } else { - None - }; +impl<'a> Drop for ConnectionPoolGuard<'a> { + fn drop(&mut self) { + #[cfg(test)] + self.check_invariants(); + } +} - self.room_updated(&room); - response.send(proto::JoinRoomResponse { - room: Some(room), - live_kit_connection_info, - })?; +impl Executor for RealExecutor { + type Sleep = Sleep; - self.update_user_contacts(session.user_id).await?; - Ok(()) + fn spawn_detached>(&self, future: F) { + tokio::task::spawn(future); } - async fn leave_room( - self: Arc, - _message: proto::LeaveRoom, - session: Session, - ) -> Result<()> { - self.leave_room_for_connection(session.connection_id, session.user_id) - .await + fn sleep(&self, duration: Duration) -> Self::Sleep { + tokio::time::sleep(duration) } +} - async fn leave_room_for_connection( - self: &Arc, - leaving_connection_id: ConnectionId, - leaving_user_id: UserId, - ) -> Result<()> { - let mut contacts_to_update = HashSet::default(); - - let 
Some(left_room) = self.app_state.db.leave_room(leaving_connection_id).await? else { - return Err(anyhow!("no room to leave"))?; - }; - contacts_to_update.insert(leaving_user_id); - - for project in left_room.left_projects.into_values() { - for connection_id in project.connection_ids { - if project.host_user_id == leaving_user_id { - self.peer - .send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); - } else { - self.peer - .send( - connection_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: leaving_connection_id.0, - }, - ) - .trace_err(); - } - } - - self.peer - .send( - leaving_connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); +fn broadcast( + sender_id: ConnectionId, + receiver_ids: impl IntoIterator, + mut f: F, +) where + F: FnMut(ConnectionId) -> anyhow::Result<()>, +{ + for receiver_id in receiver_ids { + if receiver_id != sender_id { + f(receiver_id).trace_err(); } + } +} - self.room_updated(&left_room.room); - { - let pool = self.connection_pool().await; - for canceled_user_id in left_room.canceled_calls_to_user_ids { - for connection_id in pool.user_connection_ids(canceled_user_id) { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } - contacts_to_update.insert(canceled_user_id); - } - } +lazy_static! 
{ + static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version"); +} - for contact_user_id in contacts_to_update { - self.update_user_contacts(contact_user_id).await?; - } +pub struct ProtocolVersion(u32); - if let Some(live_kit) = self.app_state.live_kit_client.as_ref() { - live_kit - .remove_participant( - left_room.room.live_kit_room.clone(), - leaving_connection_id.to_string(), - ) - .await - .trace_err(); +impl Header for ProtocolVersion { + fn name() -> &'static HeaderName { + &ZED_PROTOCOL_VERSION + } - if left_room.room.participants.is_empty() { - live_kit - .delete_room(left_room.room.live_kit_room) - .await - .trace_err(); - } - } + fn decode<'i, I>(values: &mut I) -> Result + where + Self: Sized, + I: Iterator, + { + let version = values + .next() + .ok_or_else(axum::headers::Error::invalid)? + .to_str() + .map_err(|_| axum::headers::Error::invalid())? + .parse() + .map_err(|_| axum::headers::Error::invalid())?; + Ok(Self(version)) + } - Ok(()) + fn encode>(&self, values: &mut E) { + values.extend([self.0.to_string().parse().unwrap()]); } +} - async fn call( - self: Arc, - request: proto::Call, - response: Response, - session: Session, - ) -> Result<()> { - let room_id = RoomId::from_proto(request.room_id); - let calling_user_id = session.user_id; - let calling_connection_id = session.connection_id; - let called_user_id = UserId::from_proto(request.called_user_id); - let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); - if !self - .app_state - .db - .has_contact(calling_user_id, called_user_id) - .await? 
- { - return Err(anyhow!("cannot call a user who isn't a contact"))?; +pub fn routes(server: Arc) -> Router { + Router::new() + .route("/rpc", get(handle_websocket_request)) + .layer( + ServiceBuilder::new() + .layer(Extension(server.app_state.clone())) + .layer(middleware::from_fn(auth::validate_header)), + ) + .route("/metrics", get(handle_metrics)) + .layer(Extension(server)) +} + +pub async fn handle_websocket_request( + TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, + ConnectInfo(socket_address): ConnectInfo, + Extension(server): Extension>, + Extension(user): Extension, + ws: WebSocketUpgrade, +) -> axum::response::Response { + if protocol_version != rpc::PROTOCOL_VERSION { + return ( + StatusCode::UPGRADE_REQUIRED, + "client must be upgraded".to_string(), + ) + .into_response(); + } + let socket_address = socket_address.to_string(); + ws.on_upgrade(move |socket| { + use util::ResultExt; + let socket = socket + .map_ok(to_tungstenite_message) + .err_into() + .with(|message| async move { Ok(to_axum_message(message)) }); + let connection = Connection::new(Box::pin(socket)); + async move { + server + .handle_connection(connection, socket_address, user, None, RealExecutor) + .await + .log_err(); } + }) +} - let (room, incoming_call) = self - .app_state - .db - .call( - room_id, - calling_user_id, - calling_connection_id, - called_user_id, - initial_project_id, - ) - .await?; - self.room_updated(&room); - self.update_user_contacts(called_user_id).await?; +pub async fn handle_metrics(Extension(server): Extension>) -> Result { + let connections = server + .connection_pool + .lock() + .await + .connections() + .filter(|connection| !connection.admin) + .count(); - let mut calls = self - .connection_pool() - .await - .user_connection_ids(called_user_id) - .map(|connection_id| self.peer.request(connection_id, incoming_call.clone())) - .collect::>(); - - while let Some(call_response) = calls.next().await { - match call_response.as_ref() { - Ok(_) => { - 
response.send(proto::Ack {})?; - return Ok(()); - } - Err(_) => { - call_response.trace_err(); - } - } - } + METRIC_CONNECTIONS.set(connections as _); - let room = self - .app_state - .db - .call_failed(room_id, called_user_id) - .await?; - self.room_updated(&room); - self.update_user_contacts(called_user_id).await?; + let shared_projects = server.app_state.db.project_count_excluding_admins().await?; + METRIC_SHARED_PROJECTS.set(shared_projects as _); - Err(anyhow!("failed to ring user"))? - } + let encoder = prometheus::TextEncoder::new(); + let metric_families = prometheus::gather(); + let encoded_metrics = encoder + .encode_to_string(&metric_families) + .map_err(|err| anyhow!("{}", err))?; + Ok(encoded_metrics) +} - async fn cancel_call( - self: Arc, - request: proto::CancelCall, - response: Response, - session: Session, - ) -> Result<()> { - let called_user_id = UserId::from_proto(request.called_user_id); - let room_id = RoomId::from_proto(request.room_id); - let room = self - .app_state - .db - .cancel_call(Some(room_id), session.connection_id, called_user_id) - .await?; - for connection_id in self - .connection_pool() +#[instrument(err)] +async fn sign_out(session: Session) -> Result<()> { + session.peer.disconnect(session.connection_id); + let decline_calls = { + let mut pool = session.connection_pool().await; + pool.remove_connection(session.connection_id)?; + let mut connections = pool.user_connection_ids(session.user_id); + connections.next().is_none() + }; + + leave_room_for_session(&session).await.trace_err(); + if decline_calls { + if let Some(room) = session + .db() .await - .user_connection_ids(called_user_id) + .decline_call(None, session.user_id) + .await + .trace_err() { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); + room_updated(&room, &session); } - self.room_updated(&room); - response.send(proto::Ack {})?; - - self.update_user_contacts(called_user_id).await?; - Ok(()) } - async fn decline_call( - self: Arc, - 
message: proto::DeclineCall, - session: Session, - ) -> Result<()> { - let room_id = RoomId::from_proto(message.room_id); - let room = self - .app_state - .db - .decline_call(Some(room_id), session.user_id) - .await?; - for connection_id in self - .connection_pool() + update_user_contacts(session.user_id, &session).await?; + + Ok(()) +} + +async fn ping(_: proto::Ping, response: Response, _session: Session) -> Result<()> { + response.send(proto::Ack {})?; + Ok(()) +} + +async fn create_room( + _request: proto::CreateRoom, + response: Response, + session: Session, +) -> Result<()> { + let room = session + .db() + .await + .create_room(session.user_id, session.connection_id) + .await?; + + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(_) = live_kit + .create_room(room.live_kit_room.clone()) .await - .user_connection_ids(session.user_id) + .trace_err() { - self.peer - .send(connection_id, proto::CallCanceled {}) - .trace_err(); - } - self.room_updated(&room); - self.update_user_contacts(session.user_id).await?; - Ok(()) - } - - async fn update_participant_location( - self: Arc, - request: proto::UpdateParticipantLocation, - response: Response, - session: Session, - ) -> Result<()> { - let room_id = RoomId::from_proto(request.room_id); - let location = request - .location - .ok_or_else(|| anyhow!("invalid location"))?; - let room = self - .app_state - .db - .update_room_participant_location(room_id, session.connection_id, location) - .await?; - self.room_updated(&room); - response.send(proto::Ack {})?; - Ok(()) - } - - fn room_updated(&self, room: &proto::Room) { - for participant in &room.participants { - self.peer - .send( - ConnectionId(participant.peer_id), - proto::RoomUpdated { - room: Some(room.clone()), - }, - ) - .trace_err(); - } - } - - async fn share_project( - self: Arc, - request: proto::ShareProject, - response: Response, - session: Session, - ) -> Result<()> { - let (project_id, room) = self - 
.app_state - .db - .share_project( - RoomId::from_proto(request.room_id), - session.connection_id, - &request.worktrees, - ) - .await?; - response.send(proto::ShareProjectResponse { - project_id: project_id.to_proto(), - })?; - self.room_updated(&room); - - Ok(()) - } - - async fn unshare_project( - self: Arc, - message: proto::UnshareProject, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(message.project_id); - - let (room, guest_connection_ids) = self - .app_state - .db - .unshare_project(project_id, session.connection_id) - .await?; - - broadcast(session.connection_id, guest_connection_ids, |conn_id| { - self.peer.send(conn_id, message.clone()) - }); - self.room_updated(&room); - - Ok(()) - } - - async fn update_user_contacts(self: &Arc, user_id: UserId) -> Result<()> { - let contacts = self.app_state.db.get_contacts(user_id).await?; - let busy = self.app_state.db.is_user_busy(user_id).await?; - let pool = self.connection_pool().await; - let updated_contact = contact_for_user(user_id, false, busy, &pool); - for contact in contacts { - if let db::Contact::Accepted { - user_id: contact_user_id, - .. 
- } = contact + if let Some(token) = live_kit + .room_token(&room.live_kit_room, &session.connection_id.to_string()) + .trace_err() { - for contact_conn_id in pool.user_connection_ids(contact_user_id) { - self.peer - .send( - contact_conn_id, - proto::UpdateContacts { - contacts: vec![updated_contact.clone()], - remove_contacts: Default::default(), - incoming_requests: Default::default(), - remove_incoming_requests: Default::default(), - outgoing_requests: Default::default(), - remove_outgoing_requests: Default::default(), - }, - ) - .trace_err(); - } - } - } - Ok(()) - } - - async fn join_project( - self: Arc, - request: proto::JoinProject, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let guest_user_id = session.user_id; - - tracing::info!(%project_id, "join project"); - - let (project, replica_id) = self - .app_state - .db - .join_project(project_id, session.connection_id) - .await?; - - let collaborators = project - .collaborators - .iter() - .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) - .map(|collaborator| proto::Collaborator { - peer_id: collaborator.connection_id as u32, - replica_id: collaborator.replica_id.0 as u32, - user_id: collaborator.user_id.to_proto(), - }) - .collect::>(); - let worktrees = project - .worktrees - .iter() - .map(|(id, worktree)| proto::WorktreeMetadata { - id: id.to_proto(), - root_name: worktree.root_name.clone(), - visible: worktree.visible, - abs_path: worktree.abs_path.clone(), - }) - .collect::>(); - - for collaborator in &collaborators { - self.peer - .send( - ConnectionId(collaborator.peer_id), - proto::AddProjectCollaborator { - project_id: project_id.to_proto(), - collaborator: Some(proto::Collaborator { - peer_id: session.connection_id.0, - replica_id: replica_id.0 as u32, - user_id: guest_user_id.to_proto(), - }), - }, - ) - .trace_err(); - } - - // First, we send the metadata associated with each 
worktree. - response.send(proto::JoinProjectResponse { - worktrees: worktrees.clone(), - replica_id: replica_id.0 as u32, - collaborators: collaborators.clone(), - language_servers: project.language_servers.clone(), - })?; - - for (worktree_id, worktree) in project.worktrees { - #[cfg(any(test, feature = "test-support"))] - const MAX_CHUNK_SIZE: usize = 2; - #[cfg(not(any(test, feature = "test-support")))] - const MAX_CHUNK_SIZE: usize = 256; - - // Stream this worktree's entries. - let message = proto::UpdateWorktree { - project_id: project_id.to_proto(), - worktree_id: worktree_id.to_proto(), - abs_path: worktree.abs_path.clone(), - root_name: worktree.root_name, - updated_entries: worktree.entries, - removed_entries: Default::default(), - scan_id: worktree.scan_id, - is_last_update: worktree.is_complete, - }; - for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { - self.peer.send(session.connection_id, update.clone())?; - } - - // Stream this worktree's diagnostics. - for summary in worktree.diagnostic_summaries { - self.peer.send( - session.connection_id, - proto::UpdateDiagnosticSummary { - project_id: project_id.to_proto(), - worktree_id: worktree.id.to_proto(), - summary: Some(summary), - }, - )?; + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + }) + } else { + None } + } else { + None } + } else { + None + }; + + response.send(proto::CreateRoomResponse { + room: Some(room), + live_kit_connection_info, + })?; + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} - for language_server in &project.language_servers { - self.peer.send( - session.connection_id, - proto::UpdateLanguageServer { - project_id: project_id.to_proto(), - language_server_id: language_server.id, - variant: Some( - proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( - proto::LspDiskBasedDiagnosticsUpdated {}, - ), - ), - }, - )?; - } - - Ok(()) +async fn join_room( + request: proto::JoinRoom, + response: 
Response, + session: Session, +) -> Result<()> { + let room = session + .db() + .await + .join_room( + RoomId::from_proto(request.id), + session.user_id, + session.connection_id, + ) + .await?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } - async fn leave_project( - self: Arc, - request: proto::LeaveProject, - session: Session, - ) -> Result<()> { - let sender_id = session.connection_id; - let project_id = ProjectId::from_proto(request.project_id); - let project; + let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { + if let Some(token) = live_kit + .room_token(&room.live_kit_room, &session.connection_id.to_string()) + .trace_err() { - project = self - .app_state - .db - .leave_project(project_id, sender_id) - .await?; - tracing::info!( - %project_id, - host_user_id = %project.host_user_id, - host_connection_id = %project.host_connection_id, - "leave project" - ); - - broadcast(sender_id, project.connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::RemoveProjectCollaborator { - project_id: project_id.to_proto(), - peer_id: sender_id.0, - }, - ) - }); + Some(proto::LiveKitConnectionInfo { + server_url: live_kit.url().into(), + token, + }) + } else { + None } + } else { + None + }; + + room_updated(&room, &session); + response.send(proto::JoinRoomResponse { + room: Some(room), + live_kit_connection_info, + })?; + + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} - Ok(()) - } - - async fn update_project( - self: Arc, - request: proto::UpdateProject, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let (room, guest_connection_ids) = self - .app_state - .db - .update_project(project_id, session.connection_id, &request.worktrees) - .await?; - broadcast( - session.connection_id, - 
guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - self.room_updated(&room); - response.send(proto::Ack {})?; +async fn leave_room(_message: proto::LeaveRoom, session: Session) -> Result<()> { + leave_room_for_session(&session).await +} - Ok(()) +async fn call( + request: proto::Call, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let calling_user_id = session.user_id; + let calling_connection_id = session.connection_id; + let called_user_id = UserId::from_proto(request.called_user_id); + let initial_project_id = request.initial_project_id.map(ProjectId::from_proto); + if !session + .db() + .await + .has_contact(calling_user_id, called_user_id) + .await? + { + return Err(anyhow!("cannot call a user who isn't a contact"))?; } - async fn update_worktree( - self: Arc, - request: proto::UpdateWorktree, - response: Response, - session: Session, - ) -> Result<()> { - let guest_connection_ids = self - .app_state - .db - .update_worktree(&request, session.connection_id) - .await?; + let (room, incoming_call) = session + .db() + .await + .call( + room_id, + calling_user_id, + calling_connection_id, + called_user_id, + initial_project_id, + ) + .await?; + room_updated(&room, &session); + update_user_contacts(called_user_id, &session).await?; - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - response.send(proto::Ack {})?; - Ok(()) + let mut calls = session + .connection_pool() + .await + .user_connection_ids(called_user_id) + .map(|connection_id| session.peer.request(connection_id, incoming_call.clone())) + .collect::>(); + + while let Some(call_response) = calls.next().await { + match call_response.as_ref() { + Ok(_) => { + response.send(proto::Ack {})?; + return Ok(()); + } + Err(_) => { + 
call_response.trace_err(); + } + } } - async fn update_diagnostic_summary( - self: Arc, - request: proto::UpdateDiagnosticSummary, - response: Response, - session: Session, - ) -> Result<()> { - let guest_connection_ids = self - .app_state - .db - .update_diagnostic_summary(&request, session.connection_id) - .await?; - - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - - response.send(proto::Ack {})?; - Ok(()) - } + let room = session + .db() + .await + .call_failed(room_id, called_user_id) + .await?; + room_updated(&room, &session); + update_user_contacts(called_user_id, &session).await?; - async fn start_language_server( - self: Arc, - request: proto::StartLanguageServer, - session: Session, - ) -> Result<()> { - let guest_connection_ids = self - .app_state - .db - .start_language_server(&request, session.connection_id) - .await?; + Err(anyhow!("failed to ring user"))? 
+} - broadcast( - session.connection_id, - guest_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) +async fn cancel_call( + request: proto::CancelCall, + response: Response, + session: Session, +) -> Result<()> { + let called_user_id = UserId::from_proto(request.called_user_id); + let room_id = RoomId::from_proto(request.room_id); + let room = session + .db() + .await + .cancel_call(Some(room_id), session.connection_id, called_user_id) + .await?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(called_user_id) + { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } + room_updated(&room, &session); + response.send(proto::Ack {})?; - async fn update_language_server( - self: Arc, - request: proto::UpdateLanguageServer, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } + update_user_contacts(called_user_id, &session).await?; + Ok(()) +} - async fn forward_project_request( - self: Arc, - request: T, - response: Response, - session: Session, - ) -> Result<()> - where - T: EntityMessage + RequestMessage, +async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { + let room_id = RoomId::from_proto(message.room_id); + let room = session + .db() + .await + .decline_call(Some(room_id), session.user_id) + .await?; + for connection_id in session + .connection_pool() + .await + .user_connection_ids(session.user_id) { - let project_id = ProjectId::from_proto(request.remote_entity_id()); - let collaborators = self - .app_state - .db - 
.project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; - - let payload = self + session .peer - .forward_request( - session.connection_id, - ConnectionId(host.connection_id as u32), - request, - ) - .await?; - - response.send(payload)?; - Ok(()) + .send(connection_id, proto::CallCanceled {}) + .trace_err(); } + room_updated(&room, &session); + update_user_contacts(session.user_id, &session).await?; + Ok(()) +} - async fn save_buffer( - self: Arc, - request: proto::SaveBuffer, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let collaborators = self - .app_state - .db - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .into_iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; - let host_connection_id = ConnectionId(host.connection_id as u32); - let response_payload = self - .peer - .forward_request(session.connection_id, host_connection_id, request.clone()) - .await?; - - let mut collaborators = self - .app_state - .db - .project_collaborators(project_id, session.connection_id) - .await?; - collaborators - .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); - let project_connection_ids = collaborators - .into_iter() - .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); - broadcast(host_connection_id, project_connection_ids, |conn_id| { - self.peer - .forward_send(host_connection_id, conn_id, response_payload.clone()) - }); - response.send(response_payload)?; - Ok(()) - } +async fn update_participant_location( + request: proto::UpdateParticipantLocation, + response: Response, + session: Session, +) -> Result<()> { + let room_id = RoomId::from_proto(request.room_id); + let location = request + .location + 
.ok_or_else(|| anyhow!("invalid location"))?; + let room = session + .db() + .await + .update_room_participant_location(room_id, session.connection_id, location) + .await?; + room_updated(&room, &session); + response.send(proto::Ack {})?; + Ok(()) +} - async fn create_buffer_for_peer( - self: Arc, - request: proto::CreateBufferForPeer, - session: Session, - ) -> Result<()> { - self.peer.forward_send( +async fn share_project( + request: proto::ShareProject, + response: Response, + session: Session, +) -> Result<()> { + let (project_id, room) = session + .db() + .await + .share_project( + RoomId::from_proto(request.room_id), session.connection_id, - ConnectionId(request.peer_id), - request, - )?; - Ok(()) - } + &request.worktrees, + ) + .await?; + response.send(proto::ShareProjectResponse { + project_id: project_id.to_proto(), + })?; + room_updated(&room, &session); - async fn update_buffer( - self: Arc, - request: proto::UpdateBuffer, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; + Ok(()) +} - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - response.send(proto::Ack {})?; - Ok(()) - } +async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(message.project_id); - async fn update_buffer_file( - self: Arc, - request: proto::UpdateBufferFile, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; + let (room, guest_connection_ids) = session + .db() + .await + .unshare_project(project_id, 
session.connection_id) + .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } + broadcast(session.connection_id, guest_connection_ids, |conn_id| { + session.peer.send(conn_id, message.clone()) + }); + room_updated(&room, &session); - async fn buffer_reloaded( - self: Arc, - request: proto::BufferReloaded, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } + Ok(()) +} - async fn buffer_saved( - self: Arc, - request: proto::BufferSaved, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } +async fn join_project( + request: proto::JoinProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let guest_user_id = session.user_id; - async fn follow( - self: Arc, - request: proto::Follow, - response: Response, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let leader_id = ConnectionId(request.leader_id); - let follower_id = session.connection_id; - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; + 
tracing::info!(%project_id, "join project"); - if !project_connection_ids.contains(&leader_id) { - Err(anyhow!("no such peer"))?; - } + let (project, replica_id) = session + .db() + .await + .join_project(project_id, session.connection_id) + .await?; + + let collaborators = project + .collaborators + .iter() + .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) + .map(|collaborator| proto::Collaborator { + peer_id: collaborator.connection_id as u32, + replica_id: collaborator.replica_id.0 as u32, + user_id: collaborator.user_id.to_proto(), + }) + .collect::>(); + let worktrees = project + .worktrees + .iter() + .map(|(id, worktree)| proto::WorktreeMetadata { + id: id.to_proto(), + root_name: worktree.root_name.clone(), + visible: worktree.visible, + abs_path: worktree.abs_path.clone(), + }) + .collect::>(); - let mut response_payload = self + for collaborator in &collaborators { + session .peer - .forward_request(session.connection_id, leader_id, request) - .await?; - response_payload - .views - .retain(|view| view.leader_id != Some(follower_id.0)); - response.send(response_payload)?; - Ok(()) + .send( + ConnectionId(collaborator.peer_id), + proto::AddProjectCollaborator { + project_id: project_id.to_proto(), + collaborator: Some(proto::Collaborator { + peer_id: session.connection_id.0, + replica_id: replica_id.0 as u32, + user_id: guest_user_id.to_proto(), + }), + }, + ) + .trace_err(); } - async fn unfollow(self: Arc, request: proto::Unfollow, session: Session) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let leader_id = ConnectionId(request.leader_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - if !project_connection_ids.contains(&leader_id) { - Err(anyhow!("no such peer"))?; - } - self.peer - .forward_send(session.connection_id, leader_id, request)?; - Ok(()) - } + // First, we send the metadata associated 
with each worktree. + response.send(proto::JoinProjectResponse { + worktrees: worktrees.clone(), + replica_id: replica_id.0 as u32, + collaborators: collaborators.clone(), + language_servers: project.language_servers.clone(), + })?; - async fn update_followers( - self: Arc, - request: proto::UpdateFollowers, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = session - .db - .lock() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; + for (worktree_id, worktree) in project.worktrees { + #[cfg(any(test, feature = "test-support"))] + const MAX_CHUNK_SIZE: usize = 2; + #[cfg(not(any(test, feature = "test-support")))] + const MAX_CHUNK_SIZE: usize = 256; - let leader_id = request.variant.as_ref().and_then(|variant| match variant { - proto::update_followers::Variant::CreateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, - proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, - }); - for follower_id in &request.follower_ids { - let follower_id = ConnectionId(*follower_id); - if project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { - self.peer - .forward_send(session.connection_id, follower_id, request.clone())?; - } + // Stream this worktree's entries. 
+ let message = proto::UpdateWorktree { + project_id: project_id.to_proto(), + worktree_id: worktree_id.to_proto(), + abs_path: worktree.abs_path.clone(), + root_name: worktree.root_name, + updated_entries: worktree.entries, + removed_entries: Default::default(), + scan_id: worktree.scan_id, + is_last_update: worktree.is_complete, + }; + for update in proto::split_worktree_update(message, MAX_CHUNK_SIZE) { + session.peer.send(session.connection_id, update.clone())?; } - Ok(()) - } - async fn get_users( - self: Arc, - request: proto::GetUsers, - response: Response, - _session: Session, - ) -> Result<()> { - let user_ids = request - .user_ids - .into_iter() - .map(UserId::from_proto) - .collect(); - let users = self - .app_state - .db - .get_users_by_ids(user_ids) - .await? - .into_iter() - .map(|user| proto::User { - id: user.id.to_proto(), - avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), - github_login: user.github_login, - }) - .collect(); - response.send(proto::UsersResponse { users })?; - Ok(()) + // Stream this worktree's diagnostics. + for summary in worktree.diagnostic_summaries { + session.peer.send( + session.connection_id, + proto::UpdateDiagnosticSummary { + project_id: project_id.to_proto(), + worktree_id: worktree.id.to_proto(), + summary: Some(summary), + }, + )?; + } } - async fn fuzzy_search_users( - self: Arc, - request: proto::FuzzySearchUsers, - response: Response, - session: Session, - ) -> Result<()> { - let query = request.query; - let db = &self.app_state.db; - let users = match query.len() { - 0 => vec![], - 1 | 2 => db - .get_user_by_github_account(&query, None) - .await? 
- .into_iter() - .collect(), - _ => db.fuzzy_search_users(&query, 10).await?, - }; - let users = users - .into_iter() - .filter(|user| user.id != session.user_id) - .map(|user| proto::User { - id: user.id.to_proto(), - avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), - github_login: user.github_login, - }) - .collect(); - response.send(proto::UsersResponse { users })?; - Ok(()) + for language_server in &project.language_servers { + session.peer.send( + session.connection_id, + proto::UpdateLanguageServer { + project_id: project_id.to_proto(), + language_server_id: language_server.id, + variant: Some( + proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated( + proto::LspDiskBasedDiagnosticsUpdated {}, + ), + ), + }, + )?; } - async fn request_contact( - self: Arc, - request: proto::RequestContact, - response: Response, - session: Session, - ) -> Result<()> { - let requester_id = session.user_id; - let responder_id = UserId::from_proto(request.responder_id); - if requester_id == responder_id { - return Err(anyhow!("cannot add yourself as a contact"))?; - } - - self.app_state - .db - .send_contact_request(requester_id, responder_id) - .await?; - - // Update outgoing contact requests of requester - let mut update = proto::UpdateContacts::default(); - update.outgoing_requests.push(responder_id.to_proto()); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(requester_id) - { - self.peer.send(connection_id, update.clone())?; - } + Ok(()) +} - // Update incoming contact requests of responder - let mut update = proto::UpdateContacts::default(); - update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: requester_id.to_proto(), - should_notify: true, - }); - for connection_id in self - .connection_pool() +async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { + let sender_id = session.connection_id; + let project_id = 
ProjectId::from_proto(request.project_id); + let project; + { + project = session + .db() .await - .user_connection_ids(responder_id) - { - self.peer.send(connection_id, update.clone())?; - } + .leave_project(project_id, sender_id) + .await?; + tracing::info!( + %project_id, + host_user_id = %project.host_user_id, + host_connection_id = %project.host_connection_id, + "leave project" + ); - response.send(proto::Ack {})?; - Ok(()) + broadcast(sender_id, project.connection_ids, |conn_id| { + session.peer.send( + conn_id, + proto::RemoveProjectCollaborator { + project_id: project_id.to_proto(), + peer_id: sender_id.0, + }, + ) + }); } - async fn respond_to_contact_request( - self: Arc, - request: proto::RespondToContactRequest, - response: Response, - session: Session, - ) -> Result<()> { - let responder_id = session.user_id; - let requester_id = UserId::from_proto(request.requester_id); - if request.response == proto::ContactRequestResponse::Dismiss as i32 { - self.app_state - .db - .dismiss_contact_notification(responder_id, requester_id) - .await?; - } else { - let accept = request.response == proto::ContactRequestResponse::Accept as i32; - self.app_state - .db - .respond_to_contact_request(responder_id, requester_id, accept) - .await?; - let busy = self.app_state.db.is_user_busy(requester_id).await?; - - let pool = self.connection_pool().await; - // Update responder with new contact - let mut update = proto::UpdateContacts::default(); - if accept { - update - .contacts - .push(contact_for_user(requester_id, false, busy, &pool)); - } - update - .remove_incoming_requests - .push(requester_id.to_proto()); - for connection_id in pool.user_connection_ids(responder_id) { - self.peer.send(connection_id, update.clone())?; - } + Ok(()) +} - // Update requester with new contact - let mut update = proto::UpdateContacts::default(); - if accept { - update - .contacts - .push(contact_for_user(responder_id, true, busy, &pool)); - } - update - .remove_outgoing_requests - 
.push(responder_id.to_proto()); - for connection_id in pool.user_connection_ids(requester_id) { - self.peer.send(connection_id, update.clone())?; - } - } +async fn update_project( + request: proto::UpdateProject, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let (room, guest_connection_ids) = session + .db() + .await + .update_project(project_id, session.connection_id, &request.worktrees) + .await?; + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + room_updated(&room, &session); + response.send(proto::Ack {})?; - response.send(proto::Ack {})?; - Ok(()) - } + Ok(()) +} - async fn remove_contact( - self: Arc, - request: proto::RemoveContact, - response: Response, - session: Session, - ) -> Result<()> { - let requester_id = session.user_id; - let responder_id = UserId::from_proto(request.user_id); - self.app_state - .db - .remove_contact(requester_id, responder_id) - .await?; +async fn update_worktree( + request: proto::UpdateWorktree, + response: Response, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_worktree(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) +} - // Update outgoing contact requests of requester - let mut update = proto::UpdateContacts::default(); - update - .remove_outgoing_requests - .push(responder_id.to_proto()); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(requester_id) - { - self.peer.send(connection_id, update.clone())?; - } +async fn update_diagnostic_summary( + request: proto::UpdateDiagnosticSummary, + response: Response, + 
session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .update_diagnostic_summary(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); - // Update incoming contact requests of responder - let mut update = proto::UpdateContacts::default(); - update - .remove_incoming_requests - .push(requester_id.to_proto()); - for connection_id in self - .connection_pool() - .await - .user_connection_ids(responder_id) - { - self.peer.send(connection_id, update.clone())?; - } + response.send(proto::Ack {})?; + Ok(()) +} - response.send(proto::Ack {})?; - Ok(()) - } +async fn start_language_server( + request: proto::StartLanguageServer, + session: Session, +) -> Result<()> { + let guest_connection_ids = session + .db() + .await + .start_language_server(&request, session.connection_id) + .await?; + + broadcast( + session.connection_id, + guest_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} - async fn update_diff_base( - self: Arc, - request: proto::UpdateDiffBase, - session: Session, - ) -> Result<()> { - let project_id = ProjectId::from_proto(request.project_id); - let project_connection_ids = self - .app_state - .db - .project_connection_ids(project_id, session.connection_id) - .await?; - broadcast( - session.connection_id, - project_connection_ids, - |connection_id| { - self.peer - .forward_send(session.connection_id, connection_id, request.clone()) - }, - ); - Ok(()) - } +async fn update_language_server( + request: proto::UpdateLanguageServer, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + 
broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} - async fn get_private_user_info( - self: Arc, - _request: proto::GetPrivateUserInfo, - response: Response, - session: Session, - ) -> Result<()> { - let metrics_id = self - .app_state - .db - .get_user_metrics_id(session.user_id) - .await?; - let user = self - .app_state - .db - .get_user_by_id(session.user_id) - .await? - .ok_or_else(|| anyhow!("user not found"))?; - response.send(proto::GetPrivateUserInfoResponse { - metrics_id, - staff: user.admin, - })?; - Ok(()) - } +async fn forward_project_request( + request: T, + response: Response, + session: Session, +) -> Result<()> +where + T: EntityMessage + RequestMessage, +{ + let project_id = ProjectId::from_proto(request.remote_entity_id()); + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + + let payload = session + .peer + .forward_request( + session.connection_id, + ConnectionId(host.connection_id as u32), + request, + ) + .await?; - pub(crate) async fn connection_pool(&self) -> ConnectionPoolGuard<'_> { - #[cfg(test)] - tokio::task::yield_now().await; - let guard = self.connection_pool.lock().await; - #[cfg(test)] - tokio::task::yield_now().await; - ConnectionPoolGuard { - guard, - _not_send: PhantomData, - } - } + response.send(payload)?; + Ok(()) +} - pub async fn snapshot<'a>(self: &'a Arc) -> ServerSnapshot<'a> { - ServerSnapshot { - connection_pool: self.connection_pool().await, - peer: &self.peer, - } - } +async fn save_buffer( + request: proto::SaveBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let collaborators = session + .db() + 
.await + .project_collaborators(project_id, session.connection_id) + .await?; + let host = collaborators + .into_iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + let host_connection_id = ConnectionId(host.connection_id as u32); + let response_payload = session + .peer + .forward_request(session.connection_id, host_connection_id, request.clone()) + .await?; + + let mut collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + collaborators + .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); + let project_connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); + broadcast(host_connection_id, project_connection_ids, |conn_id| { + session + .peer + .forward_send(host_connection_id, conn_id, response_payload.clone()) + }); + response.send(response_payload)?; + Ok(()) } -impl<'a> Deref for ConnectionPoolGuard<'a> { - type Target = ConnectionPool; +async fn create_buffer_for_peer( + request: proto::CreateBufferForPeer, + session: Session, +) -> Result<()> { + session.peer.forward_send( + session.connection_id, + ConnectionId(request.peer_id), + request, + )?; + Ok(()) +} - fn deref(&self) -> &Self::Target { - &*self.guard - } +async fn update_buffer( + request: proto::UpdateBuffer, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + response.send(proto::Ack {})?; + Ok(()) } -impl<'a> DerefMut for ConnectionPoolGuard<'a> { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut *self.guard - } 
+async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) } -impl<'a> Drop for ConnectionPoolGuard<'a> { - fn drop(&mut self) { - #[cfg(test)] - self.check_invariants(); - } +async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) } -impl Executor for RealExecutor { - type Sleep = Sleep; +async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} - fn spawn_detached>(&self, future: F) { - tokio::task::spawn(future); - } +async fn follow( + request: proto::Follow, + response: Response, + session: Session, +) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let follower_id = session.connection_id; + let project_connection_ids = session + .db() + .await + 
.project_connection_ids(project_id, session.connection_id) + .await?; + + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; + } + + let mut response_payload = session + .peer + .forward_request(session.connection_id, leader_id, request) + .await?; + response_payload + .views + .retain(|view| view.leader_id != Some(follower_id.0)); + response.send(response_payload)?; + Ok(()) +} - fn sleep(&self, duration: Duration) -> Self::Sleep { - tokio::time::sleep(duration) - } +async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let leader_id = ConnectionId(request.leader_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; + } + session + .peer + .forward_send(session.connection_id, leader_id, request)?; + Ok(()) } -fn broadcast( - sender_id: ConnectionId, - receiver_ids: impl IntoIterator, - mut f: F, -) where - F: FnMut(ConnectionId) -> anyhow::Result<()>, -{ - for receiver_id in receiver_ids { - if receiver_id != sender_id { - f(receiver_id).trace_err(); +async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db + .lock() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + + let leader_id = request.variant.as_ref().and_then(|variant| match variant { + proto::update_followers::Variant::CreateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateView(payload) => payload.leader_id, + proto::update_followers::Variant::UpdateActiveView(payload) => payload.leader_id, + }); + for follower_id in &request.follower_ids { + let follower_id = ConnectionId(*follower_id); + if 
project_connection_ids.contains(&follower_id) && Some(follower_id.0) != leader_id { + session + .peer + .forward_send(session.connection_id, follower_id, request.clone())?; } } + Ok(()) } -lazy_static! { - static ref ZED_PROTOCOL_VERSION: HeaderName = HeaderName::from_static("x-zed-protocol-version"); +async fn get_users( + request: proto::GetUsers, + response: Response, + session: Session, +) -> Result<()> { + let user_ids = request + .user_ids + .into_iter() + .map(UserId::from_proto) + .collect(); + let users = session + .db() + .await + .get_users_by_ids(user_ids) + .await? + .into_iter() + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) } -pub struct ProtocolVersion(u32); +async fn fuzzy_search_users( + request: proto::FuzzySearchUsers, + response: Response, + session: Session, +) -> Result<()> { + let query = request.query; + let users = match query.len() { + 0 => vec![], + 1 | 2 => session + .db() + .await + .get_user_by_github_account(&query, None) + .await? 
+ .into_iter() + .collect(), + _ => session.db().await.fuzzy_search_users(&query, 10).await?, + }; + let users = users + .into_iter() + .filter(|user| user.id != session.user_id) + .map(|user| proto::User { + id: user.id.to_proto(), + avatar_url: format!("https://github.com/{}.png?size=128", user.github_login), + github_login: user.github_login, + }) + .collect(); + response.send(proto::UsersResponse { users })?; + Ok(()) +} -impl Header for ProtocolVersion { - fn name() -> &'static HeaderName { - &ZED_PROTOCOL_VERSION +async fn request_contact( + request: proto::RequestContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.responder_id); + if requester_id == responder_id { + return Err(anyhow!("cannot add yourself as a contact"))?; } - fn decode<'i, I>(values: &mut I) -> Result - where - Self: Sized, - I: Iterator, + session + .db() + .await + .send_contact_request(requester_id, responder_id) + .await?; + + // Update outgoing contact requests of requester + let mut update = proto::UpdateContacts::default(); + update.outgoing_requests.push(responder_id.to_proto()); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(requester_id) { - let version = values - .next() - .ok_or_else(axum::headers::Error::invalid)? - .to_str() - .map_err(|_| axum::headers::Error::invalid())? 
- .parse() - .map_err(|_| axum::headers::Error::invalid())?; - Ok(Self(version)) + session.peer.send(connection_id, update.clone())?; } - fn encode>(&self, values: &mut E) { - values.extend([self.0.to_string().parse().unwrap()]); + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: requester_id.to_proto(), + should_notify: true, + }); + for connection_id in session + .connection_pool() + .await + .user_connection_ids(responder_id) + { + session.peer.send(connection_id, update.clone())?; } -} -pub fn routes(server: Arc) -> Router { - Router::new() - .route("/rpc", get(handle_websocket_request)) - .layer( - ServiceBuilder::new() - .layer(Extension(server.app_state.clone())) - .layer(middleware::from_fn(auth::validate_header)), - ) - .route("/metrics", get(handle_metrics)) - .layer(Extension(server)) + response.send(proto::Ack {})?; + Ok(()) } -pub async fn handle_websocket_request( - TypedHeader(ProtocolVersion(protocol_version)): TypedHeader, - ConnectInfo(socket_address): ConnectInfo, - Extension(server): Extension>, - Extension(user): Extension, - ws: WebSocketUpgrade, -) -> axum::response::Response { - if protocol_version != rpc::PROTOCOL_VERSION { - return ( - StatusCode::UPGRADE_REQUIRED, - "client must be upgraded".to_string(), - ) - .into_response(); - } - let socket_address = socket_address.to_string(); - ws.on_upgrade(move |socket| { - use util::ResultExt; - let socket = socket - .map_ok(to_tungstenite_message) - .err_into() - .with(|message| async move { Ok(to_axum_message(message)) }); - let connection = Connection::new(Box::pin(socket)); - async move { - server - .handle_connection(connection, socket_address, user, None, RealExecutor) - .await - .log_err(); +async fn respond_to_contact_request( + request: proto::RespondToContactRequest, + response: Response, + session: Session, +) -> Result<()> { + let responder_id = 
session.user_id; + let requester_id = UserId::from_proto(request.requester_id); + let db = session.db().await; + if request.response == proto::ContactRequestResponse::Dismiss as i32 { + db.dismiss_contact_notification(responder_id, requester_id) + .await?; + } else { + let accept = request.response == proto::ContactRequestResponse::Accept as i32; + + db.respond_to_contact_request(responder_id, requester_id, accept) + .await?; + let busy = db.is_user_busy(requester_id).await?; + + let pool = session.connection_pool().await; + // Update responder with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(requester_id, false, busy, &pool)); } - }) + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } + + // Update requester with new contact + let mut update = proto::UpdateContacts::default(); + if accept { + update + .contacts + .push(contact_for_user(responder_id, true, busy, &pool)); + } + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } + } + + response.send(proto::Ack {})?; + Ok(()) } -pub async fn handle_metrics(Extension(server): Extension>) -> Result { - let connections = server - .connection_pool() - .await - .connections() - .filter(|connection| !connection.admin) - .count(); +async fn remove_contact( + request: proto::RemoveContact, + response: Response, + session: Session, +) -> Result<()> { + let requester_id = session.user_id; + let responder_id = UserId::from_proto(request.user_id); + let db = session.db().await; + db.remove_contact(requester_id, responder_id).await?; + + let pool = session.connection_pool().await; + // Update outgoing contact requests of requester + let mut update = 
proto::UpdateContacts::default(); + update + .remove_outgoing_requests + .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { + session.peer.send(connection_id, update.clone())?; + } - METRIC_CONNECTIONS.set(connections as _); + // Update incoming contact requests of responder + let mut update = proto::UpdateContacts::default(); + update + .remove_incoming_requests + .push(requester_id.to_proto()); + for connection_id in pool.user_connection_ids(responder_id) { + session.peer.send(connection_id, update.clone())?; + } - let shared_projects = server.app_state.db.project_count_excluding_admins().await?; - METRIC_SHARED_PROJECTS.set(shared_projects as _); + response.send(proto::Ack {})?; + Ok(()) +} - let encoder = prometheus::TextEncoder::new(); - let metric_families = prometheus::gather(); - let encoded_metrics = encoder - .encode_to_string(&metric_families) - .map_err(|err| anyhow!("{}", err))?; - Ok(encoded_metrics) +async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { + let project_id = ProjectId::from_proto(request.project_id); + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; + broadcast( + session.connection_id, + project_connection_ids, + |connection_id| { + session + .peer + .forward_send(session.connection_id, connection_id, request.clone()) + }, + ); + Ok(()) +} + +async fn get_private_user_info( + _request: proto::GetPrivateUserInfo, + response: Response, + session: Session, +) -> Result<()> { + let metrics_id = session + .db() + .await + .get_user_metrics_id(session.user_id) + .await?; + let user = session + .db() + .await + .get_user_by_id(session.user_id) + .await? 
+ .ok_or_else(|| anyhow!("user not found"))?; + response.send(proto::GetPrivateUserInfoResponse { + metrics_id, + staff: user.admin, + })?; + Ok(()) } fn to_axum_message(message: TungsteniteMessage) -> AxumMessage { @@ -1941,6 +1775,137 @@ fn contact_for_user( } } +fn room_updated(room: &proto::Room, session: &Session) { + for participant in &room.participants { + session + .peer + .send( + ConnectionId(participant.peer_id), + proto::RoomUpdated { + room: Some(room.clone()), + }, + ) + .trace_err(); + } +} + +async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> { + let db = session.db().await; + let contacts = db.get_contacts(user_id).await?; + let busy = db.is_user_busy(user_id).await?; + + let pool = session.connection_pool().await; + let updated_contact = contact_for_user(user_id, false, busy, &pool); + for contact in contacts { + if let db::Contact::Accepted { + user_id: contact_user_id, + .. + } = contact + { + for contact_conn_id in pool.user_connection_ids(contact_user_id) { + session + .peer + .send( + contact_conn_id, + proto::UpdateContacts { + contacts: vec![updated_contact.clone()], + remove_contacts: Default::default(), + incoming_requests: Default::default(), + remove_incoming_requests: Default::default(), + outgoing_requests: Default::default(), + remove_outgoing_requests: Default::default(), + }, + ) + .trace_err(); + } + } + } + Ok(()) +} + +async fn leave_room_for_session(session: &Session) -> Result<()> { + let mut contacts_to_update = HashSet::default(); + + let Some(left_room) = session.db().await.leave_room(session.connection_id).await? 
else { + return Err(anyhow!("no room to leave"))?; + }; + contacts_to_update.insert(session.user_id); + + for project in left_room.left_projects.into_values() { + for connection_id in project.connection_ids { + if project.host_user_id == session.user_id { + session + .peer + .send( + connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } else { + session + .peer + .send( + connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: session.connection_id.0, + }, + ) + .trace_err(); + } + } + + session + .peer + .send( + session.connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } + + room_updated(&left_room.room, &session); + { + let pool = session.connection_pool().await; + for canceled_user_id in left_room.canceled_calls_to_user_ids { + for connection_id in pool.user_connection_ids(canceled_user_id) { + session + .peer + .send(connection_id, proto::CallCanceled {}) + .trace_err(); + } + contacts_to_update.insert(canceled_user_id); + } + } + + for contact_user_id in contacts_to_update { + update_user_contacts(contact_user_id, &session).await?; + } + + if let Some(live_kit) = session.live_kit_client.as_ref() { + live_kit + .remove_participant( + left_room.room.live_kit_room.clone(), + session.connection_id.to_string(), + ) + .await + .trace_err(); + + if left_room.room.participants.is_empty() { + live_kit + .delete_room(left_room.room.live_kit_room) + .await + .trace_err(); + } + } + + Ok(()) +} + pub trait ResultExt { type Ok; From 4c1b4953c17b48c19b57d6e9eb0247059f5de85f Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 18 Nov 2022 20:18:48 +0100 Subject: [PATCH 057/109] Remove version from `Room` We won't need it once we add the per-room lock. 
--- crates/call/src/room.rs | 10 --- .../20221109000000_test_schema.sql | 1 - .../20221111092550_reconnection_support.sql | 1 - crates/collab/src/db.rs | 71 +++++++++---------- crates/rpc/proto/zed.proto | 7 +- 5 files changed, 37 insertions(+), 53 deletions(-) diff --git a/crates/call/src/room.rs b/crates/call/src/room.rs index 8c1b0d9de09f42ecf48e10d67c31b1a6b5508350..f8a55a3a931a9d349cb4c1a38db753d9e92846cd 100644 --- a/crates/call/src/room.rs +++ b/crates/call/src/room.rs @@ -34,7 +34,6 @@ pub enum Event { pub struct Room { id: u64, - version: u64, live_kit: Option, status: RoomStatus, local_participant: LocalParticipant, @@ -62,7 +61,6 @@ impl Entity for Room { impl Room { fn new( id: u64, - version: u64, live_kit_connection_info: Option, client: Arc, user_store: ModelHandle, @@ -135,7 +133,6 @@ impl Room { Self { id, - version, live_kit: live_kit_room, status: RoomStatus::Online, participant_user_ids: Default::default(), @@ -164,7 +161,6 @@ impl Room { let room = cx.add_model(|cx| { Self::new( room_proto.id, - room_proto.version, response.live_kit_connection_info, client, user_store, @@ -209,7 +205,6 @@ impl Room { let room = cx.add_model(|cx| { Self::new( room_id, - 0, response.live_kit_connection_info, client, user_store, @@ -321,10 +316,6 @@ impl Room { futures::join!(remote_participants, pending_participants); this.update(&mut cx, |this, cx| { - if this.version >= room.version { - return; - } - this.participant_user_ids.clear(); if let Some(participant) = local_participant { @@ -429,7 +420,6 @@ impl Room { let _ = this.leave(cx); } - this.version = room.version; this.check_invariants(); cx.notify(); }); diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 66925fddd55fba36464eef2fab7b4f30af75362f..02ca0c75a9d40132970cf08d8961d828e4d2f07f 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ 
b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -36,7 +36,6 @@ CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); CREATE TABLE "rooms" ( "id" INTEGER PRIMARY KEY, - "version" INTEGER NOT NULL, "live_kit_room" VARCHAR NOT NULL ); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 2b8f7824cb4bea6a138fc983ee206d69464aedf0..b742f8e0cd0b2595641b77f756687ad17cdd9aba 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -1,6 +1,5 @@ CREATE TABLE IF NOT EXISTS "rooms" ( "id" SERIAL PRIMARY KEY, - "version" INTEGER NOT NULL, "live_kit_room" VARCHAR NOT NULL ); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 6cb53738817c567a887335ad3a1f41c5c24be859..54d4497833f10fb7b6b1d5aa3901ad59d176903d 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -931,13 +931,12 @@ where let live_kit_room = nanoid::nanoid!(30); let room_id = sqlx::query_scalar( " - INSERT INTO rooms (live_kit_room, version) - VALUES ($1, $2) + INSERT INTO rooms (live_kit_room) + VALUES ($1) RETURNING id ", ) .bind(&live_kit_room) - .bind(0) .fetch_one(&mut tx) .await .map(RoomId)?; @@ -956,7 +955,9 @@ where .execute(&mut tx) .await?; - self.commit_room_transaction(room_id, tx).await + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(room) }).await } @@ -983,7 +984,9 @@ where .execute(&mut tx) .await?; - let room = self.commit_room_transaction(room_id, tx).await?; + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + let incoming_call = Self::build_incoming_call(&room, called_user_id) .ok_or_else(|| anyhow!("failed to build incoming call"))?; Ok((room, incoming_call)) @@ -1061,7 +1064,9 @@ where .execute(&mut tx) .await?; - self.commit_room_transaction(room_id, tx).await + let room = 
self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(room) }) .await } @@ -1086,7 +1091,9 @@ where return Err(anyhow!("declining call on unexpected room"))?; } - self.commit_room_transaction(room_id, tx).await + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(room) }) .await } @@ -1113,7 +1120,9 @@ where return Err(anyhow!("canceling call on unexpected room"))?; } - self.commit_room_transaction(room_id, tx).await + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(room) }).await } @@ -1137,7 +1146,10 @@ where .bind(user_id) .fetch_one(&mut tx) .await?; - self.commit_room_transaction(room_id, tx).await + + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(room) }) .await } @@ -1245,7 +1257,9 @@ where .execute(&mut tx) .await?; - let room = self.commit_room_transaction(room_id, tx).await?; + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(Some(LeftRoom { room, left_projects, @@ -1302,32 +1316,13 @@ where .fetch_one(&mut tx) .await?; - self.commit_room_transaction(room_id, tx).await + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; + Ok(room) }) .await } - async fn commit_room_transaction( - &self, - room_id: RoomId, - mut tx: sqlx::Transaction<'_, D>, - ) -> Result { - sqlx::query( - " - UPDATE rooms - SET version = version + 1 - WHERE id = $1 - ", - ) - .bind(room_id) - .execute(&mut tx) - .await?; - let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - - Ok(room) - } - async fn get_guest_connection_ids( &self, project_id: ProjectId, @@ -1455,7 +1450,6 @@ where Ok(proto::Room { id: room.id.to_proto(), - version: room.version as u64, live_kit_room: room.live_kit_room, participants: participants.into_values().collect(), pending_participants, @@ -1565,7 +1559,9 @@ where .execute(&mut tx) .await?; - let room = self.commit_room_transaction(room_id, tx).await?; + let room = self.get_room(room_id, 
&mut tx).await?; + tx.commit().await?; + Ok((project_id, room)) }) .await @@ -1589,7 +1585,8 @@ where .bind(connection_id.0 as i32) .fetch_one(&mut tx) .await?; - let room = self.commit_room_transaction(room_id, tx).await?; + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; Ok((room, guest_connection_ids)) }) @@ -1666,7 +1663,8 @@ where query.execute(&mut tx).await?; let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - let room = self.commit_room_transaction(room_id, tx).await?; + let room = self.get_room(room_id, &mut tx).await?; + tx.commit().await?; Ok((room, guest_connection_ids)) }) @@ -2614,7 +2612,6 @@ id_type!(RoomId); #[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] pub struct Room { pub id: RoomId, - pub version: i32, pub live_kit_room: String, } diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index 30c1c89e8f8b393f96e13c96ad9ea42e14ff7a7e..6f26e0dfa14727053a0e205dd031346dfe393d18 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -160,10 +160,9 @@ message LeaveRoom {} message Room { uint64 id = 1; - uint64 version = 2; - repeated Participant participants = 3; - repeated PendingParticipant pending_participants = 4; - string live_kit_room = 5; + repeated Participant participants = 2; + repeated PendingParticipant pending_participants = 3; + string live_kit_room = 4; } message Participant { From ae11e4f798f8a0af13f4bd46bf32ddd33602cd3a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 13:56:03 +0100 Subject: [PATCH 058/109] Check the correct serialization failure code when retrying transaction --- crates/collab/src/db.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 54d4497833f10fb7b6b1d5aa3901ad59d176903d..295234af618a662e2717d4a123ca740f35269781 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2520,7 +2520,7 @@ where 
.as_database_error() .and_then(|error| error.code()) .as_deref() - == Some("hey") => + == Some("40001") => { // Retry (don't break the loop) } From b0e1d6bc7f5cd6986ab1666639d866207d72ee44 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 13:57:15 +0100 Subject: [PATCH 059/109] Fix integration test incorrectly assuming a certain ordering --- crates/collab/src/integration_tests.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 006598a6b191e593c7934d145a3c146da0a7c496..cf6bb8af3ad80251b1d1f5b9ddf12c577eb3977f 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -2422,7 +2422,10 @@ async fn test_collaborating_with_diagnostics( // Join project as client C and observe the diagnostics. let project_c = client_c.build_remote_project(project_id, cx_c).await; - let project_c_diagnostic_summaries = Rc::new(RefCell::new(Vec::new())); + let project_c_diagnostic_summaries = + Rc::new(RefCell::new(project_c.read_with(cx_c, |project, cx| { + project.diagnostic_summaries(cx).collect::>() + }))); project_c.update(cx_c, |_, cx| { let summaries = project_c_diagnostic_summaries.clone(); cx.subscribe(&project_c, { From 5581674f8f4a8b256d986f20e0ddb4c1d84bc0af Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 14:39:27 +0100 Subject: [PATCH 060/109] After completing LSP request, return an error if guest is disconnected --- crates/project/src/project.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 503ae8d4b24cc290e539121e50e2803939a9ecc7..30b0ac25060b16548b37ee8165d84bbd976356b4 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -4109,9 +4109,13 @@ impl Project { let message = request.to_proto(project_id, buffer); return cx.spawn(|this, cx| async move { let 
response = rpc.request(message).await?; - request - .response_from_proto(response, this, buffer_handle, cx) - .await + if this.read_with(&cx, |this, _| this.is_read_only()) { + Err(anyhow!("disconnected before completing request")) + } else { + request + .response_from_proto(response, this, buffer_handle, cx) + .await + } }); } Task::ready(Ok(Default::default())) From 2a0ddd99d28ab53d0e5df72145584f9a8949a48f Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 15:05:34 +0100 Subject: [PATCH 061/109] Error if project is disconnected after getting code actions response --- crates/project/src/project.rs | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 30b0ac25060b16548b37ee8165d84bbd976356b4..fb77da9347db5bc3099cfc11f8461994c1becb43 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -3579,7 +3579,7 @@ impl Project { } else if let Some(project_id) = self.remote_id() { let rpc = self.client.clone(); let version = buffer.version(); - cx.spawn_weak(|_, mut cx| async move { + cx.spawn_weak(|this, mut cx| async move { let response = rpc .request(proto::GetCodeActions { project_id, @@ -3590,17 +3590,27 @@ impl Project { }) .await?; - buffer_handle - .update(&mut cx, |buffer, _| { - buffer.wait_for_version(deserialize_version(response.version)) - }) - .await; + if this + .upgrade(&cx) + .ok_or_else(|| anyhow!("project was dropped"))? 
+ .read_with(&cx, |this, _| this.is_read_only()) + { + return Err(anyhow!( + "failed to get code actions: project was disconnected" + )); + } else { + buffer_handle + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(deserialize_version(response.version)) + }) + .await; - response - .actions - .into_iter() - .map(language::proto::deserialize_code_action) - .collect() + response + .actions + .into_iter() + .map(language::proto::deserialize_code_action) + .collect() + } }) } else { Task::ready(Ok(Default::default())) From cd0b663f6285f24f74a6445bc870b2e94ab610cd Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 17:00:47 +0100 Subject: [PATCH 062/109] Introduce per-room lock acquired before committing a transaction --- Cargo.lock | 14 ++ crates/collab/Cargo.toml | 1 + crates/collab/src/db.rs | 254 ++++++++++++++++++++++------------- crates/collab/src/rpc.rs | 283 +++++++++++++++++++++------------------ 4 files changed, 328 insertions(+), 224 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b6f86980ae5f792a9d22fb6936599b6a5ab9cf4b..8cd5e7d6d7ba748271c0e230ed1b4682e1bb50dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1041,6 +1041,7 @@ dependencies = [ "client", "collections", "ctor", + "dashmap", "editor", "env_logger", "envy", @@ -1536,6 +1537,19 @@ dependencies = [ "syn", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if 1.0.0", + "hashbrown 0.12.3", + "lock_api", + "once_cell", + "parking_lot_core 0.9.4", +] + [[package]] name = "data-url" version = "0.1.1" diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index f04918605ff6a1e6e4911cbfeb01e7de045b6525..e5a97b9764d2d476af4f22dff89147f2cf06698b 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -24,6 +24,7 @@ axum = { version = "0.5", features = ["json", "headers", "ws"] } 
axum-extra = { version = "0.3", features = ["erased-json"] } base64 = "0.13" clap = { version = "3.1", features = ["derive"], optional = true } +dashmap = "5.4" envy = "0.4.2" futures = "0.3" hyper = "0.14" diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 295234af618a662e2717d4a123ca740f35269781..84ad5082d017e616d2bfdd37523818be61dd1f86 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2,6 +2,7 @@ use crate::{Error, Result}; use anyhow::anyhow; use axum::http::StatusCode; use collections::{BTreeMap, HashMap, HashSet}; +use dashmap::DashMap; use futures::{future::BoxFuture, FutureExt, StreamExt}; use rpc::{proto, ConnectionId}; use serde::{Deserialize, Serialize}; @@ -10,8 +11,17 @@ use sqlx::{ types::Uuid, FromRow, }; -use std::{future::Future, path::Path, time::Duration}; +use std::{ + future::Future, + marker::PhantomData, + ops::{Deref, DerefMut}, + path::Path, + rc::Rc, + sync::Arc, + time::Duration, +}; use time::{OffsetDateTime, PrimitiveDateTime}; +use tokio::sync::{Mutex, OwnedMutexGuard}; #[cfg(test)] pub type DefaultDb = Db; @@ -21,12 +31,33 @@ pub type DefaultDb = Db; pub struct Db { pool: sqlx::Pool, + rooms: DashMap>>, #[cfg(test)] background: Option>, #[cfg(test)] runtime: Option, } +pub struct RoomGuard { + data: T, + _guard: OwnedMutexGuard<()>, + _not_send: PhantomData>, +} + +impl Deref for RoomGuard { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for RoomGuard { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + pub trait BeginTransaction: Send + Sync { type Database: sqlx::Database; @@ -90,6 +121,7 @@ impl Db { .await?; Ok(Self { pool, + rooms: Default::default(), background: None, runtime: None, }) @@ -197,6 +229,7 @@ impl Db { .await?; Ok(Self { pool, + rooms: DashMap::with_capacity(16384), #[cfg(test)] background: None, #[cfg(test)] @@ -922,13 +955,29 @@ where .await } + async fn commit_room_transaction<'a, T>( + &'a self, + room_id: RoomId, + tx: 
sqlx::Transaction<'static, D>, + data: T, + ) -> Result> { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + tx.commit().await?; + Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }) + } + pub async fn create_room( &self, user_id: UserId, connection_id: ConnectionId, - ) -> Result { + live_kit_room: &str, + ) -> Result> { self.transact(|mut tx| async move { - let live_kit_room = nanoid::nanoid!(30); let room_id = sqlx::query_scalar( " INSERT INTO rooms (live_kit_room) @@ -956,8 +1005,7 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - Ok(room) + self.commit_room_transaction(room_id, tx, room).await }).await } @@ -968,11 +1016,17 @@ where calling_connection_id: ConnectionId, called_user_id: UserId, initial_project_id: Option, - ) -> Result<(proto::Room, proto::IncomingCall)> { + ) -> Result> { self.transact(|mut tx| async move { sqlx::query( " - INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id) + INSERT INTO room_participants ( + room_id, + user_id, + calling_user_id, + calling_connection_id, + initial_project_id + ) VALUES ($1, $2, $3, $4, $5) ", ) @@ -985,12 +1039,12 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - let incoming_call = Self::build_incoming_call(&room, called_user_id) .ok_or_else(|| anyhow!("failed to build incoming call"))?; - Ok((room, incoming_call)) - }).await + self.commit_room_transaction(room_id, tx, (room, incoming_call)) + .await + }) + .await } pub async fn incoming_call_for_user( @@ -1051,7 +1105,7 @@ where &self, room_id: RoomId, called_user_id: UserId, - ) -> Result { + ) -> Result> { self.transact(|mut tx| async move { sqlx::query( " @@ -1065,8 +1119,7 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - Ok(room) + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ 
-1075,7 +1128,7 @@ where &self, expected_room_id: Option, user_id: UserId, - ) -> Result { + ) -> Result> { self.transact(|mut tx| async move { let room_id = sqlx::query_scalar( " @@ -1092,8 +1145,7 @@ where } let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - Ok(room) + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ -1103,7 +1155,7 @@ where expected_room_id: Option, calling_connection_id: ConnectionId, called_user_id: UserId, - ) -> Result { + ) -> Result> { self.transact(|mut tx| async move { let room_id = sqlx::query_scalar( " @@ -1121,8 +1173,7 @@ where } let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - Ok(room) + self.commit_room_transaction(room_id, tx, room).await }).await } @@ -1131,7 +1182,7 @@ where room_id: RoomId, user_id: UserId, connection_id: ConnectionId, - ) -> Result { + ) -> Result> { self.transact(|mut tx| async move { sqlx::query( " @@ -1148,13 +1199,15 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - Ok(room) + self.commit_room_transaction(room_id, tx, room).await }) .await } - pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { + pub async fn leave_room( + &self, + connection_id: ConnectionId, + ) -> Result>> { self.transact(|mut tx| async move { // Leave room. 
let room_id = sqlx::query_scalar::<_, RoomId>( @@ -1258,13 +1311,18 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - - Ok(Some(LeftRoom { - room, - left_projects, - canceled_calls_to_user_ids, - })) + Ok(Some( + self.commit_room_transaction( + room_id, + tx, + LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + }, + ) + .await?, + )) } else { Ok(None) } @@ -1277,7 +1335,7 @@ where room_id: RoomId, connection_id: ConnectionId, location: proto::ParticipantLocation, - ) -> Result { + ) -> Result> { self.transact(|tx| async { let mut tx = tx; let location_kind; @@ -1317,8 +1375,7 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - Ok(room) + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ -1478,7 +1535,7 @@ where expected_room_id: RoomId, connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], - ) -> Result<(ProjectId, proto::Room)> { + ) -> Result> { self.transact(|mut tx| async move { let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( " @@ -1560,9 +1617,8 @@ where .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - - Ok((project_id, room)) + self.commit_room_transaction(room_id, tx, (project_id, room)) + .await }) .await } @@ -1571,7 +1627,7 @@ where &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result<(proto::Room, Vec)> { + ) -> Result)>> { self.transact(|mut tx| async move { let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let room_id: RoomId = sqlx::query_scalar( @@ -1586,9 +1642,8 @@ where .fetch_one(&mut tx) .await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - - Ok((room, guest_connection_ids)) + self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) + .await }) .await } @@ -1598,7 +1653,7 @@ where project_id: ProjectId, connection_id: ConnectionId, worktrees: 
&[proto::WorktreeMetadata], - ) -> Result<(proto::Room, Vec)> { + ) -> Result)>> { self.transact(|mut tx| async move { let room_id: RoomId = sqlx::query_scalar( " @@ -1664,9 +1719,8 @@ where let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; let room = self.get_room(room_id, &mut tx).await?; - tx.commit().await?; - - Ok((room, guest_connection_ids)) + self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) + .await }) .await } @@ -1675,15 +1729,15 @@ where &self, update: &proto::UpdateWorktree, connection_id: ConnectionId, - ) -> Result> { + ) -> Result>> { self.transact(|mut tx| async move { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = WorktreeId::from_proto(update.worktree_id); // Ensure the update comes from the host. - sqlx::query( + let room_id: RoomId = sqlx::query_scalar( " - SELECT 1 + SELECT room_id FROM projects WHERE id = $1 AND host_connection_id = $2 ", @@ -1781,8 +1835,8 @@ where } let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - tx.commit().await?; - Ok(connection_ids) + self.commit_room_transaction(room_id, tx, connection_ids) + .await }) .await } @@ -1791,7 +1845,7 @@ where &self, update: &proto::UpdateDiagnosticSummary, connection_id: ConnectionId, - ) -> Result> { + ) -> Result>> { self.transact(|mut tx| async { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = WorktreeId::from_proto(update.worktree_id); @@ -1801,9 +1855,9 @@ where .ok_or_else(|| anyhow!("invalid summary"))?; // Ensure the update comes from the host. 
- sqlx::query( + let room_id: RoomId = sqlx::query_scalar( " - SELECT 1 + SELECT room_id FROM projects WHERE id = $1 AND host_connection_id = $2 ", @@ -1841,8 +1895,8 @@ where .await?; let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - tx.commit().await?; - Ok(connection_ids) + self.commit_room_transaction(room_id, tx, connection_ids) + .await }) .await } @@ -1851,7 +1905,7 @@ where &self, update: &proto::StartLanguageServer, connection_id: ConnectionId, - ) -> Result> { + ) -> Result>> { self.transact(|mut tx| async { let project_id = ProjectId::from_proto(update.project_id); let server = update @@ -1860,9 +1914,9 @@ where .ok_or_else(|| anyhow!("invalid language server"))?; // Ensure the update comes from the host. - sqlx::query( + let room_id: RoomId = sqlx::query_scalar( " - SELECT 1 + SELECT room_id FROM projects WHERE id = $1 AND host_connection_id = $2 ", @@ -1888,8 +1942,8 @@ where .await?; let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - tx.commit().await?; - Ok(connection_ids) + self.commit_room_transaction(room_id, tx, connection_ids) + .await }) .await } @@ -1898,7 +1952,7 @@ where &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result<(Project, ReplicaId)> { + ) -> Result> { self.transact(|mut tx| async move { let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( " @@ -2068,21 +2122,25 @@ where .fetch_all(&mut tx) .await?; - tx.commit().await?; - Ok(( - Project { - collaborators, - worktrees, - language_servers: language_servers - .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id.to_proto(), - name: language_server.name, - }) - .collect(), - }, - replica_id as ReplicaId, - )) + self.commit_room_transaction( + room_id, + tx, + ( + Project { + collaborators, + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id.to_proto(), + name: 
language_server.name, + }) + .collect(), + }, + replica_id as ReplicaId, + ), + ) + .await }) .await } @@ -2091,7 +2149,7 @@ where &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result { + ) -> Result> { self.transact(|mut tx| async move { let result = sqlx::query( " @@ -2122,25 +2180,29 @@ where .map(|id| ConnectionId(id as u32)) .collect(); - let (host_user_id, host_connection_id) = sqlx::query_as::<_, (i32, i32)>( - " - SELECT host_user_id, host_connection_id + let (room_id, host_user_id, host_connection_id) = + sqlx::query_as::<_, (RoomId, i32, i32)>( + " + SELECT room_id, host_user_id, host_connection_id FROM projects WHERE id = $1 ", - ) - .bind(project_id) - .fetch_one(&mut tx) - .await?; - - tx.commit().await?; + ) + .bind(project_id) + .fetch_one(&mut tx) + .await?; - Ok(LeftProject { - id: project_id, - host_user_id: UserId(host_user_id), - host_connection_id: ConnectionId(host_connection_id as u32), - connection_ids, - }) + self.commit_room_transaction( + room_id, + tx, + LeftProject { + id: project_id, + host_user_id: UserId(host_user_id), + host_connection_id: ConnectionId(host_connection_id as u32), + connection_ids, + }, + ) + .await }) .await } @@ -2538,9 +2600,9 @@ where let result = self.runtime.as_ref().unwrap().block_on(body); - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } + // if let Some(background) = self.background.as_ref() { + // background.simulate_random_delay().await; + // } result } diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index ba97b09acd1a72b0fb7340c9a4ace2e8b62cffca..07b98914808a6fcffc74710886a4d0c07d8e9a79 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -42,6 +42,7 @@ use std::{ fmt, future::Future, marker::PhantomData, + mem, net::SocketAddr, ops::{Deref, DerefMut}, rc::Rc, @@ -702,20 +703,15 @@ async fn create_room( response: Response, session: Session, ) -> Result<()> { - let room = session - .db() - .await - 
.create_room(session.user_id, session.connection_id) - .await?; - + let live_kit_room = nanoid::nanoid!(30); let live_kit_connection_info = if let Some(live_kit) = session.live_kit_client.as_ref() { if let Some(_) = live_kit - .create_room(room.live_kit_room.clone()) + .create_room(live_kit_room.clone()) .await .trace_err() { if let Some(token) = live_kit - .room_token(&room.live_kit_room, &session.connection_id.to_string()) + .room_token(&live_kit_room, &session.connection_id.to_string()) .trace_err() { Some(proto::LiveKitConnectionInfo { @@ -732,10 +728,19 @@ async fn create_room( None }; - response.send(proto::CreateRoomResponse { - room: Some(room), - live_kit_connection_info, - })?; + { + let room = session + .db() + .await + .create_room(session.user_id, session.connection_id, &live_kit_room) + .await?; + + response.send(proto::CreateRoomResponse { + room: Some(room.clone()), + live_kit_connection_info, + })?; + } + update_user_contacts(session.user_id, &session).await?; Ok(()) } @@ -745,15 +750,20 @@ async fn join_room( response: Response, session: Session, ) -> Result<()> { - let room = session - .db() - .await - .join_room( - RoomId::from_proto(request.id), - session.user_id, - session.connection_id, - ) - .await?; + let room = { + let room = session + .db() + .await + .join_room( + RoomId::from_proto(request.id), + session.user_id, + session.connection_id, + ) + .await?; + room_updated(&room, &session); + room.clone() + }; + for connection_id in session .connection_pool() .await @@ -781,7 +791,6 @@ async fn join_room( None }; - room_updated(&room, &session); response.send(proto::JoinRoomResponse { room: Some(room), live_kit_connection_info, @@ -814,18 +823,21 @@ async fn call( return Err(anyhow!("cannot call a user who isn't a contact"))?; } - let (room, incoming_call) = session - .db() - .await - .call( - room_id, - calling_user_id, - calling_connection_id, - called_user_id, - initial_project_id, - ) - .await?; - room_updated(&room, &session); + let 
incoming_call = { + let (room, incoming_call) = &mut *session + .db() + .await + .call( + room_id, + calling_user_id, + calling_connection_id, + called_user_id, + initial_project_id, + ) + .await?; + room_updated(&room, &session); + mem::take(incoming_call) + }; update_user_contacts(called_user_id, &session).await?; let mut calls = session @@ -847,12 +859,14 @@ async fn call( } } - let room = session - .db() - .await - .call_failed(room_id, called_user_id) - .await?; - room_updated(&room, &session); + { + let room = session + .db() + .await + .call_failed(room_id, called_user_id) + .await?; + room_updated(&room, &session); + } update_user_contacts(called_user_id, &session).await?; Err(anyhow!("failed to ring user"))? @@ -865,11 +879,15 @@ async fn cancel_call( ) -> Result<()> { let called_user_id = UserId::from_proto(request.called_user_id); let room_id = RoomId::from_proto(request.room_id); - let room = session - .db() - .await - .cancel_call(Some(room_id), session.connection_id, called_user_id) - .await?; + { + let room = session + .db() + .await + .cancel_call(Some(room_id), session.connection_id, called_user_id) + .await?; + room_updated(&room, &session); + } + for connection_id in session .connection_pool() .await @@ -880,7 +898,6 @@ async fn cancel_call( .send(connection_id, proto::CallCanceled {}) .trace_err(); } - room_updated(&room, &session); response.send(proto::Ack {})?; update_user_contacts(called_user_id, &session).await?; @@ -889,11 +906,15 @@ async fn cancel_call( async fn decline_call(message: proto::DeclineCall, session: Session) -> Result<()> { let room_id = RoomId::from_proto(message.room_id); - let room = session - .db() - .await - .decline_call(Some(room_id), session.user_id) - .await?; + { + let room = session + .db() + .await + .decline_call(Some(room_id), session.user_id) + .await?; + room_updated(&room, &session); + } + for connection_id in session .connection_pool() .await @@ -904,7 +925,6 @@ async fn decline_call(message: 
proto::DeclineCall, session: Session) -> Result<( .send(connection_id, proto::CallCanceled {}) .trace_err(); } - room_updated(&room, &session); update_user_contacts(session.user_id, &session).await?; Ok(()) } @@ -933,7 +953,7 @@ async fn share_project( response: Response, session: Session, ) -> Result<()> { - let (project_id, room) = session + let (project_id, room) = &*session .db() .await .share_project( @@ -953,15 +973,17 @@ async fn share_project( async fn unshare_project(message: proto::UnshareProject, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(message.project_id); - let (room, guest_connection_ids) = session + let (room, guest_connection_ids) = &*session .db() .await .unshare_project(project_id, session.connection_id) .await?; - broadcast(session.connection_id, guest_connection_ids, |conn_id| { - session.peer.send(conn_id, message.clone()) - }); + broadcast( + session.connection_id, + guest_connection_ids.iter().copied(), + |conn_id| session.peer.send(conn_id, message.clone()), + ); room_updated(&room, &session); Ok(()) @@ -977,7 +999,7 @@ async fn join_project( tracing::info!(%project_id, "join project"); - let (project, replica_id) = session + let (project, replica_id) = &mut *session .db() .await .join_project(project_id, session.connection_id) @@ -1029,7 +1051,7 @@ async fn join_project( language_servers: project.language_servers.clone(), })?; - for (worktree_id, worktree) in project.worktrees { + for (worktree_id, worktree) in mem::take(&mut project.worktrees) { #[cfg(any(test, feature = "test-support"))] const MAX_CHUNK_SIZE: usize = 2; #[cfg(not(any(test, feature = "test-support")))] @@ -1084,21 +1106,23 @@ async fn join_project( async fn leave_project(request: proto::LeaveProject, session: Session) -> Result<()> { let sender_id = session.connection_id; let project_id = ProjectId::from_proto(request.project_id); - let project; - { - project = session - .db() - .await - .leave_project(project_id, sender_id) - .await?; - 
tracing::info!( - %project_id, - host_user_id = %project.host_user_id, - host_connection_id = %project.host_connection_id, - "leave project" - ); - broadcast(sender_id, project.connection_ids, |conn_id| { + let project = session + .db() + .await + .leave_project(project_id, sender_id) + .await?; + tracing::info!( + %project_id, + host_user_id = %project.host_user_id, + host_connection_id = %project.host_connection_id, + "leave project" + ); + + broadcast( + sender_id, + project.connection_ids.iter().copied(), + |conn_id| { session.peer.send( conn_id, proto::RemoveProjectCollaborator { @@ -1106,8 +1130,8 @@ async fn leave_project(request: proto::LeaveProject, session: Session) -> Result peer_id: sender_id.0, }, ) - }); - } + }, + ); Ok(()) } @@ -1118,14 +1142,14 @@ async fn update_project( session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let (room, guest_connection_ids) = session + let (room, guest_connection_ids) = &*session .db() .await .update_project(project_id, session.connection_id, &request.worktrees) .await?; broadcast( session.connection_id, - guest_connection_ids, + guest_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1151,7 +1175,7 @@ async fn update_worktree( broadcast( session.connection_id, - guest_connection_ids, + guest_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1175,7 +1199,7 @@ async fn update_diagnostic_summary( broadcast( session.connection_id, - guest_connection_ids, + guest_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1199,7 +1223,7 @@ async fn start_language_server( broadcast( session.connection_id, - guest_connection_ids, + guest_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1826,52 +1850,61 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> async fn leave_room_for_session(session: &Session) -> Result<()> { let mut contacts_to_update = HashSet::default(); - let 
Some(left_room) = session.db().await.leave_room(session.connection_id).await? else { - return Err(anyhow!("no room to leave"))?; - }; - contacts_to_update.insert(session.user_id); - - for project in left_room.left_projects.into_values() { - for connection_id in project.connection_ids { - if project.host_user_id == session.user_id { - session - .peer - .send( - connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); - } else { - session - .peer - .send( - connection_id, - proto::RemoveProjectCollaborator { - project_id: project.id.to_proto(), - peer_id: session.connection_id.0, - }, - ) - .trace_err(); + let canceled_calls_to_user_ids; + let live_kit_room; + let delete_live_kit_room; + { + let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? else { + return Err(anyhow!("no room to leave"))?; + }; + contacts_to_update.insert(session.user_id); + + for project in left_room.left_projects.values() { + for connection_id in &project.connection_ids { + if project.host_user_id == session.user_id { + session + .peer + .send( + *connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); + } else { + session + .peer + .send( + *connection_id, + proto::RemoveProjectCollaborator { + project_id: project.id.to_proto(), + peer_id: session.connection_id.0, + }, + ) + .trace_err(); + } } + + session + .peer + .send( + session.connection_id, + proto::UnshareProject { + project_id: project.id.to_proto(), + }, + ) + .trace_err(); } - session - .peer - .send( - session.connection_id, - proto::UnshareProject { - project_id: project.id.to_proto(), - }, - ) - .trace_err(); + room_updated(&left_room.room, &session); + canceled_calls_to_user_ids = mem::take(&mut left_room.canceled_calls_to_user_ids); + live_kit_room = mem::take(&mut left_room.room.live_kit_room); + delete_live_kit_room = left_room.room.participants.is_empty(); } - room_updated(&left_room.room, &session); 
{ let pool = session.connection_pool().await; - for canceled_user_id in left_room.canceled_calls_to_user_ids { + for canceled_user_id in canceled_calls_to_user_ids { for connection_id in pool.user_connection_ids(canceled_user_id) { session .peer @@ -1888,18 +1921,12 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { if let Some(live_kit) = session.live_kit_client.as_ref() { live_kit - .remove_participant( - left_room.room.live_kit_room.clone(), - session.connection_id.to_string(), - ) + .remove_participant(live_kit_room.clone(), session.connection_id.to_string()) .await .trace_err(); - if left_room.room.participants.is_empty() { - live_kit - .delete_room(left_room.room.live_kit_room) - .await - .trace_err(); + if delete_live_kit_room { + live_kit.delete_room(live_kit_room).await.trace_err(); } } From af2a2d2494e2f72194aed7d4d2b012f4694e2dec Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 17:43:40 +0100 Subject: [PATCH 063/109] Return error when waiting on a worktree snapshot after disconnecting --- crates/project/src/worktree.rs | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 836ac55b661157f8c2f0297567b55143b8b26d2a..791cd1d622ff8fd8cf983c55f21f1c9cd303604c 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -81,6 +81,7 @@ pub struct RemoteWorktree { replica_id: ReplicaId, diagnostic_summaries: TreeMap, visible: bool, + disconnected: bool, } #[derive(Clone)] @@ -248,6 +249,7 @@ impl Worktree { client: client.clone(), diagnostic_summaries: Default::default(), visible, + disconnected: false, }) }); @@ -1069,6 +1071,7 @@ impl RemoteWorktree { pub fn disconnected_from_host(&mut self) { self.updates_tx.take(); self.snapshot_subscriptions.clear(); + self.disconnected = true; } pub fn update_from_remote(&mut self, update: proto::UpdateWorktree) { @@ -1083,10 +1086,12 @@ impl RemoteWorktree { 
self.scan_id > scan_id || (self.scan_id == scan_id && self.is_complete) } - fn wait_for_snapshot(&mut self, scan_id: usize) -> impl Future { + fn wait_for_snapshot(&mut self, scan_id: usize) -> impl Future> { let (tx, rx) = oneshot::channel(); if self.observed_snapshot(scan_id) { let _ = tx.send(()); + } else if self.disconnected { + drop(tx); } else { match self .snapshot_subscriptions @@ -1097,7 +1102,8 @@ impl RemoteWorktree { } async move { - let _ = rx.await; + rx.await?; + Ok(()) } } @@ -1126,7 +1132,7 @@ impl RemoteWorktree { ) -> Task> { let wait_for_snapshot = self.wait_for_snapshot(scan_id); cx.spawn(|this, mut cx| async move { - wait_for_snapshot.await; + wait_for_snapshot.await?; this.update(&mut cx, |worktree, _| { let worktree = worktree.as_remote_mut().unwrap(); let mut snapshot = worktree.background_snapshot.lock(); @@ -1145,7 +1151,7 @@ impl RemoteWorktree { ) -> Task> { let wait_for_snapshot = self.wait_for_snapshot(scan_id); cx.spawn(|this, mut cx| async move { - wait_for_snapshot.await; + wait_for_snapshot.await?; this.update(&mut cx, |worktree, _| { let worktree = worktree.as_remote_mut().unwrap(); let mut snapshot = worktree.background_snapshot.lock(); From 0a565c6bae9e8ce1377c44bb608b2e305120ac75 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 17:44:18 +0100 Subject: [PATCH 064/109] :lipstick: --- crates/collab/src/db.rs | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 84ad5082d017e616d2bfdd37523818be61dd1f86..eff97855c6a601992cc8c2a96d953e063b0a6cb5 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2598,13 +2598,7 @@ where background.simulate_random_delay().await; } - let result = self.runtime.as_ref().unwrap().block_on(body); - - // if let Some(background) = self.background.as_ref() { - // background.simulate_random_delay().await; - // } - - result + self.runtime.as_ref().unwrap().block_on(body) } #[cfg(not(test))] 
From f0a721032d70f58469a61c399a64d24ce748752e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 18:56:11 +0100 Subject: [PATCH 065/109] Remove non-determinism caused by random entropy when reconnecting --- crates/client/src/client.rs | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index bad85384be6b78cce7a0b1f33d48dc471fcff22b..c75aef3a1ad51c83ef6845d9160fe8dffba2b783 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -398,7 +398,11 @@ impl Client { let this = self.clone(); let reconnect_interval = state.reconnect_interval; state._reconnect_task = Some(cx.spawn(|cx| async move { + #[cfg(any(test, feature = "test-support"))] + let mut rng = StdRng::seed_from_u64(0); + #[cfg(not(any(test, feature = "test-support")))] let mut rng = StdRng::from_entropy(); + let mut delay = INITIAL_RECONNECTION_DELAY; while let Err(error) = this.authenticate_and_connect(true, &cx).await { log::error!("failed to connect {}", error); From fa3f100effebd136ad9e2a4a53908aa979465dd3 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 19:01:28 +0100 Subject: [PATCH 066/109] Introduce a new `detect_nondeterminism = true` attribute to `gpui::test` --- crates/gpui/src/executor.rs | 72 ++++++++++++++++++++--- crates/gpui/src/test.rs | 84 ++++++++++++++++++++++----- crates/gpui_macros/src/gpui_macros.rs | 27 ++++++--- 3 files changed, 150 insertions(+), 33 deletions(-) diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 0639445b0d1f2c35a65fd9777e3d96165bcbb702..5231f8a51a0c6f5dbb25a299944251ccca125e08 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -66,21 +66,31 @@ struct DeterministicState { rng: rand::prelude::StdRng, seed: u64, scheduled_from_foreground: collections::HashMap>, - scheduled_from_background: Vec, + scheduled_from_background: Vec, forbid_parking: bool, block_on_ticks: std::ops::RangeInclusive, 
now: std::time::Instant, next_timer_id: usize, pending_timers: Vec<(usize, std::time::Instant, postage::barrier::Sender)>, waiting_backtrace: Option, + next_runnable_id: usize, + poll_history: Vec, + runnable_backtraces: collections::HashMap, } #[cfg(any(test, feature = "test-support"))] struct ForegroundRunnable { + id: usize, runnable: Runnable, main: bool, } +#[cfg(any(test, feature = "test-support"))] +struct BackgroundRunnable { + id: usize, + runnable: Runnable, +} + #[cfg(any(test, feature = "test-support"))] pub struct Deterministic { state: Arc>, @@ -117,11 +127,24 @@ impl Deterministic { next_timer_id: Default::default(), pending_timers: Default::default(), waiting_backtrace: None, + next_runnable_id: 0, + poll_history: Default::default(), + runnable_backtraces: Default::default(), })), parker: Default::default(), }) } + pub fn runnable_history(&self) -> Vec { + self.state.lock().poll_history.clone() + } + + pub fn runnable_backtrace(&self, runnable_id: usize) -> backtrace::Backtrace { + let mut backtrace = self.state.lock().runnable_backtraces[&runnable_id].clone(); + backtrace.resolve(); + backtrace + } + pub fn build_background(self: &Arc) -> Arc { Arc::new(Background::Deterministic { executor: self.clone(), @@ -142,6 +165,15 @@ impl Deterministic { main: bool, ) -> AnyLocalTask { let state = self.state.clone(); + let id; + { + let mut state = state.lock(); + id = util::post_inc(&mut state.next_runnable_id); + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } + let unparker = self.parker.lock().unparker(); let (runnable, task) = async_task::spawn_local(future, move |runnable| { let mut state = state.lock(); @@ -149,7 +181,7 @@ impl Deterministic { .scheduled_from_foreground .entry(cx_id) .or_default() - .push(ForegroundRunnable { runnable, main }); + .push(ForegroundRunnable { id, runnable, main }); unparker.unpark(); }); runnable.schedule(); @@ -158,10 +190,21 @@ impl Deterministic { fn spawn(&self, future: 
AnyFuture) -> AnyTask { let state = self.state.clone(); + let id; + { + let mut state = state.lock(); + id = util::post_inc(&mut state.next_runnable_id); + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } + let unparker = self.parker.lock().unparker(); let (runnable, task) = async_task::spawn(future, move |runnable| { let mut state = state.lock(); - state.scheduled_from_background.push(runnable); + state + .scheduled_from_background + .push(BackgroundRunnable { id, runnable }); unparker.unpark(); }); runnable.schedule(); @@ -178,15 +221,25 @@ impl Deterministic { let woken = Arc::new(AtomicBool::new(false)); let state = self.state.clone(); + let id; + { + let mut state = state.lock(); + id = util::post_inc(&mut state.next_runnable_id); + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new()); + } + let unparker = self.parker.lock().unparker(); let (runnable, mut main_task) = unsafe { async_task::spawn_unchecked(main_future, move |runnable| { - let mut state = state.lock(); + let state = &mut *state.lock(); state .scheduled_from_foreground .entry(cx_id) .or_default() .push(ForegroundRunnable { + id: util::post_inc(&mut state.next_runnable_id), runnable, main: true, }); @@ -248,9 +301,10 @@ impl Deterministic { if !state.scheduled_from_background.is_empty() && state.rng.gen() { let background_len = state.scheduled_from_background.len(); let ix = state.rng.gen_range(0..background_len); - let runnable = state.scheduled_from_background.remove(ix); + let background_runnable = state.scheduled_from_background.remove(ix); + state.poll_history.push(background_runnable.id); drop(state); - runnable.run(); + background_runnable.runnable.run(); } else if !state.scheduled_from_foreground.is_empty() { let available_cx_ids = state .scheduled_from_foreground @@ -266,6 +320,7 @@ impl Deterministic { if scheduled_from_cx.is_empty() { state.scheduled_from_foreground.remove(&cx_id_to_run); } + 
state.poll_history.push(foreground_runnable.id); drop(state); @@ -298,9 +353,10 @@ impl Deterministic { let runnable_count = state.scheduled_from_background.len(); let ix = state.rng.gen_range(0..=runnable_count); if ix < state.scheduled_from_background.len() { - let runnable = state.scheduled_from_background.remove(ix); + let background_runnable = state.scheduled_from_background.remove(ix); + state.poll_history.push(background_runnable.id); drop(state); - runnable.run(); + background_runnable.runnable.run(); } else { drop(state); if let Poll::Ready(result) = future.poll(&mut cx) { diff --git a/crates/gpui/src/test.rs b/crates/gpui/src/test.rs index e76b094c9a586951ea0bab55ff3e058553635535..665033a71c13fa16bef5f5ebad91fcf98b9d4e3d 100644 --- a/crates/gpui/src/test.rs +++ b/crates/gpui/src/test.rs @@ -1,11 +1,13 @@ use crate::{ - elements::Empty, executor, platform, Element, ElementBox, Entity, FontCache, Handle, - LeakDetector, MutableAppContext, Platform, RenderContext, Subscription, TestAppContext, View, + elements::Empty, executor, platform, util::CwdBacktrace, Element, ElementBox, Entity, + FontCache, Handle, LeakDetector, MutableAppContext, Platform, RenderContext, Subscription, + TestAppContext, View, }; use futures::StreamExt; use parking_lot::Mutex; use smol::channel; use std::{ + fmt::Write, panic::{self, RefUnwindSafe}, rc::Rc, sync::{ @@ -29,13 +31,13 @@ pub fn run_test( mut num_iterations: u64, mut starting_seed: u64, max_retries: usize, + detect_nondeterminism: bool, test_fn: &mut (dyn RefUnwindSafe + Fn( &mut MutableAppContext, Rc, Arc, u64, - bool, )), fn_name: String, ) { @@ -60,10 +62,10 @@ pub fn run_test( let platform = Arc::new(platform::test::platform()); let font_system = platform.fonts(); let font_cache = Arc::new(FontCache::new(font_system)); + let mut prev_runnable_history: Option> = None; - loop { - let seed = atomic_seed.fetch_add(1, SeqCst); - let is_last_iteration = seed + 1 >= starting_seed + num_iterations; + for _ in 
0..num_iterations { + let seed = atomic_seed.load(SeqCst); if is_randomized { dbg!(seed); @@ -82,13 +84,7 @@ pub fn run_test( fn_name.clone(), ); cx.update(|cx| { - test_fn( - cx, - foreground_platform.clone(), - deterministic.clone(), - seed, - is_last_iteration, - ); + test_fn(cx, foreground_platform.clone(), deterministic.clone(), seed); }); cx.update(|cx| cx.remove_all_windows()); @@ -96,8 +92,64 @@ pub fn run_test( cx.update(|cx| cx.clear_globals()); leak_detector.lock().detect(); - if is_last_iteration { - break; + + if detect_nondeterminism { + let curr_runnable_history = deterministic.runnable_history(); + if let Some(prev_runnable_history) = prev_runnable_history { + let mut prev_entries = prev_runnable_history.iter().fuse(); + let mut curr_entries = curr_runnable_history.iter().fuse(); + + let mut nondeterministic = false; + let mut common_history_prefix = Vec::new(); + let mut prev_history_suffix = Vec::new(); + let mut curr_history_suffix = Vec::new(); + loop { + match (prev_entries.next(), curr_entries.next()) { + (None, None) => break, + (None, Some(curr_id)) => curr_history_suffix.push(*curr_id), + (Some(prev_id), None) => prev_history_suffix.push(*prev_id), + (Some(prev_id), Some(curr_id)) => { + if nondeterministic { + prev_history_suffix.push(*prev_id); + curr_history_suffix.push(*curr_id); + } else if prev_id == curr_id { + common_history_prefix.push(*curr_id); + } else { + nondeterministic = true; + prev_history_suffix.push(*prev_id); + curr_history_suffix.push(*curr_id); + } + } + } + } + + if nondeterministic { + let mut error = String::new(); + writeln!(&mut error, "Common prefix: {:?}", common_history_prefix) + .unwrap(); + writeln!(&mut error, "Previous suffix: {:?}", prev_history_suffix) + .unwrap(); + writeln!(&mut error, "Current suffix: {:?}", curr_history_suffix) + .unwrap(); + + let last_common_backtrace = common_history_prefix + .last() + .map(|runnable_id| deterministic.runnable_backtrace(*runnable_id)); + + writeln!( + &mut error, 
+ "Last future that ran on both executions: {:?}", + last_common_backtrace.as_ref().map(CwdBacktrace) + ) + .unwrap(); + panic!("Detected non-determinism.\n{}", error); + } + } + prev_runnable_history = Some(curr_runnable_history); + } + + if !detect_nondeterminism { + atomic_seed.fetch_add(1, SeqCst); } } }); @@ -112,7 +164,7 @@ pub fn run_test( println!("retrying: attempt {}", retries); } else { if is_randomized { - eprintln!("failing seed: {}", atomic_seed.load(SeqCst) - 1); + eprintln!("failing seed: {}", atomic_seed.load(SeqCst)); } panic::resume_unwind(error); } diff --git a/crates/gpui_macros/src/gpui_macros.rs b/crates/gpui_macros/src/gpui_macros.rs index b43bedc64315ffc9f1b7b98845e064e6cc67555d..e28d1711d2ceda9cfcedf28310575e1cbc3cc620 100644 --- a/crates/gpui_macros/src/gpui_macros.rs +++ b/crates/gpui_macros/src/gpui_macros.rs @@ -14,6 +14,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { let mut max_retries = 0; let mut num_iterations = 1; let mut starting_seed = 0; + let mut detect_nondeterminism = false; for arg in args { match arg { @@ -26,6 +27,9 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { let key_name = meta.path.get_ident().map(|i| i.to_string()); let result = (|| { match key_name.as_deref() { + Some("detect_nondeterminism") => { + detect_nondeterminism = parse_bool(&meta.lit)? 
+ } Some("retries") => max_retries = parse_int(&meta.lit)?, Some("iterations") => num_iterations = parse_int(&meta.lit)?, Some("seed") => starting_seed = parse_int(&meta.lit)?, @@ -77,10 +81,6 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); continue; } - Some("bool") => { - inner_fn_args.extend(quote!(is_last_iteration,)); - continue; - } Some("Arc") => { if let syn::PathArguments::AngleBracketed(args) = &last_segment.unwrap().arguments @@ -146,7 +146,8 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { #num_iterations as u64, #starting_seed as u64, #max_retries, - &mut |cx, foreground_platform, deterministic, seed, is_last_iteration| { + #detect_nondeterminism, + &mut |cx, foreground_platform, deterministic, seed| { #cx_vars cx.foreground().run(#inner_fn_name(#inner_fn_args)); #cx_teardowns @@ -165,9 +166,6 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { Some("StdRng") => { inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); } - Some("bool") => { - inner_fn_args.extend(quote!(is_last_iteration,)); - } _ => {} } } else { @@ -189,7 +187,8 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { #num_iterations as u64, #starting_seed as u64, #max_retries, - &mut |cx, _, _, seed, is_last_iteration| #inner_fn_name(#inner_fn_args), + #detect_nondeterminism, + &mut |cx, _, _, seed| #inner_fn_name(#inner_fn_args), stringify!(#outer_fn_name).to_string(), ); } @@ -209,3 +208,13 @@ fn parse_int(literal: &Lit) -> Result { result.map_err(|err| TokenStream::from(err.into_compile_error())) } + +fn parse_bool(literal: &Lit) -> Result { + let result = if let Lit::Bool(result) = &literal { + Ok(result.value) + } else { + Err(syn::Error::new(literal.span(), "must be a boolean")) + }; + + result.map_err(|err| TokenStream::from(err.into_compile_error())) +} From d0709e7bfa53d128aaeb3b7dab49d28dd735f7ce Mon 
Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 19:18:31 +0100 Subject: [PATCH 067/109] Error if project is disconnected after getting completions response --- crates/project/src/project.rs | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index fb77da9347db5bc3099cfc11f8461994c1becb43..a3439430fdea9290638a158a9a3788374a4280da 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -3408,19 +3408,29 @@ impl Project { position: Some(language::proto::serialize_anchor(&anchor)), version: serialize_version(&source_buffer.version()), }; - cx.spawn_weak(|_, mut cx| async move { + cx.spawn_weak(|this, mut cx| async move { let response = rpc.request(message).await?; - source_buffer_handle - .update(&mut cx, |buffer, _| { - buffer.wait_for_version(deserialize_version(response.version)) - }) - .await; + if this + .upgrade(&cx) + .ok_or_else(|| anyhow!("project was dropped"))? 
+ .read_with(&cx, |this, _| this.is_read_only()) + { + return Err(anyhow!( + "failed to get completions: project was disconnected" + )); + } else { + source_buffer_handle + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(deserialize_version(response.version)) + }) + .await; - let completions = response.completions.into_iter().map(|completion| { - language::proto::deserialize_completion(completion, language.clone()) - }); - futures::future::try_join_all(completions).await + let completions = response.completions.into_iter().map(|completion| { + language::proto::deserialize_completion(completion, language.clone()) + }); + futures::future::try_join_all(completions).await + } }) } else { Task::ready(Ok(Default::default())) From cd2a8579b9dbd2ed2023d4da2d24b6219861c25e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 28 Nov 2022 19:35:33 +0100 Subject: [PATCH 068/109] Capture runnable backtraces only when detecting nondeterminism --- crates/gpui/src/executor.rs | 30 +++++++++++++++++++++--------- crates/gpui/src/test.rs | 4 ++++ 2 files changed, 25 insertions(+), 9 deletions(-) diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index 5231f8a51a0c6f5dbb25a299944251ccca125e08..876e48351d6e8e224df3dcefd2a953414b4436b9 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -75,6 +75,7 @@ struct DeterministicState { waiting_backtrace: Option, next_runnable_id: usize, poll_history: Vec, + enable_runnable_backtraces: bool, runnable_backtraces: collections::HashMap, } @@ -129,6 +130,7 @@ impl Deterministic { waiting_backtrace: None, next_runnable_id: 0, poll_history: Default::default(), + enable_runnable_backtraces: false, runnable_backtraces: Default::default(), })), parker: Default::default(), @@ -139,6 +141,10 @@ impl Deterministic { self.state.lock().poll_history.clone() } + pub fn enable_runnable_backtrace(&self) { + self.state.lock().enable_runnable_backtraces = true; + } + pub fn runnable_backtrace(&self, 
runnable_id: usize) -> backtrace::Backtrace { let mut backtrace = self.state.lock().runnable_backtraces[&runnable_id].clone(); backtrace.resolve(); @@ -169,9 +175,11 @@ impl Deterministic { { let mut state = state.lock(); id = util::post_inc(&mut state.next_runnable_id); - state - .runnable_backtraces - .insert(id, backtrace::Backtrace::new_unresolved()); + if state.enable_runnable_backtraces { + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } } let unparker = self.parker.lock().unparker(); @@ -194,9 +202,11 @@ impl Deterministic { { let mut state = state.lock(); id = util::post_inc(&mut state.next_runnable_id); - state - .runnable_backtraces - .insert(id, backtrace::Backtrace::new_unresolved()); + if state.enable_runnable_backtraces { + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } } let unparker = self.parker.lock().unparker(); @@ -225,9 +235,11 @@ impl Deterministic { { let mut state = state.lock(); id = util::post_inc(&mut state.next_runnable_id); - state - .runnable_backtraces - .insert(id, backtrace::Backtrace::new()); + if state.enable_runnable_backtraces { + state + .runnable_backtraces + .insert(id, backtrace::Backtrace::new_unresolved()); + } } let unparker = self.parker.lock().unparker(); diff --git a/crates/gpui/src/test.rs b/crates/gpui/src/test.rs index 665033a71c13fa16bef5f5ebad91fcf98b9d4e3d..aade1054a8d919590bded33c09dc4c458a6579e6 100644 --- a/crates/gpui/src/test.rs +++ b/crates/gpui/src/test.rs @@ -72,6 +72,10 @@ pub fn run_test( } let deterministic = executor::Deterministic::new(seed); + if detect_nondeterminism { + deterministic.enable_runnable_backtrace(); + } + let leak_detector = Arc::new(Mutex::new(LeakDetector::default())); let mut cx = TestAppContext::new( foreground_platform.clone(), From d525cfd697efae7a06e605c51f2da4703fdc484e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 29 Nov 2022 11:02:14 +0100 Subject: [PATCH 069/109] Increase 
probability of creating new files in randomized test --- crates/collab/src/integration_tests.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index cf6bb8af3ad80251b1d1f5b9ddf12c577eb3977f..93ff73fc838cf961b03dcd0ca5740a64625e2bae 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -6391,7 +6391,7 @@ impl TestClient { buffers.extend(search.await?.into_keys()); } } - 60..=69 => { + 60..=79 => { let worktree = project .read_with(cx, |project, cx| { project From ac24600a4022716bc1aa4c305572b4e7141d5ec2 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 29 Nov 2022 13:55:08 +0100 Subject: [PATCH 070/109] Start moving towards using sea-query to construct queries --- Cargo.lock | 34 +++++ Cargo.toml | 1 + crates/collab/Cargo.toml | 15 +- crates/collab/src/db.rs | 134 +++++++++++------- crates/collab/src/db/schema.rs | 43 ++++++ .../collab/src/{db_tests.rs => db/tests.rs} | 2 +- crates/collab/src/main.rs | 2 - 7 files changed, 168 insertions(+), 63 deletions(-) create mode 100644 crates/collab/src/db/schema.rs rename crates/collab/src/{db_tests.rs => db/tests.rs} (99%) diff --git a/Cargo.lock b/Cargo.lock index 8cd5e7d6d7ba748271c0e230ed1b4682e1bb50dc..5083b9131266ee6987ccbf299c0ad0f86f2cd1bd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1065,6 +1065,8 @@ dependencies = [ "reqwest", "rpc", "scrypt", + "sea-query", + "sea-query-binder", "serde", "serde_json", "settings", @@ -5121,6 +5123,38 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sea-query" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a4f0fc4d8e44e1d51c739a68d336252a18bc59553778075d5e32649be6ec92ed" +dependencies = [ + "sea-query-derive", +] + +[[package]] +name = "sea-query-binder" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"9c2585b89c985cfacfe0ec9fc9e7bb055b776c1a2581c4e3c6185af2b8bf8865" +dependencies = [ + "sea-query", + "sqlx", +] + +[[package]] +name = "sea-query-derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34cdc022b4f606353fe5dc85b09713a04e433323b70163e81513b141c6ae6eb5" +dependencies = [ + "heck 0.3.3", + "proc-macro2", + "quote", + "syn", + "thiserror", +] + [[package]] name = "seahash" version = "4.1.0" diff --git a/Cargo.toml b/Cargo.toml index 205017da1fbc156543b143fc13238780767e7734..03fcb4cfd9dc3ad5f360f25d2802681c4a8518d8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ rand = { version = "0.8" } [patch.crates-io] tree-sitter = { git = "https://github.com/tree-sitter/tree-sitter", rev = "366210ae925d7ea0891bc7a0c738f60c77c04d7b" } async-task = { git = "https://github.com/zed-industries/async-task", rev = "341b57d6de98cdfd7b418567b8de2022ca993a6e" } +sqlx = { git = "https://github.com/launchbadge/sqlx", rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" } # TODO - Remove when a version is released with this PR: https://github.com/servo/core-foundation-rs/pull/457 cocoa = { git = "https://github.com/servo/core-foundation-rs", rev = "079665882507dd5e2ff77db3de5070c1f6c0fb85" } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index e5a97b9764d2d476af4f22dff89147f2cf06698b..e854b003c8bdf5d4257cdf6bc05a7c8641ed256d 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -36,9 +36,12 @@ prometheus = "0.13" rand = "0.8" reqwest = { version = "0.11", features = ["json"], optional = true } scrypt = "0.7" +sea-query = { version = "0.27", features = ["derive"] } +sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] } serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" sha-1 = "0.9" +sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } time = { version = "0.3", features = ["serde", "serde-well-known"] 
} tokio = { version = "1", features = ["full"] } tokio-tungstenite = "0.17" @@ -49,11 +52,6 @@ tracing = "0.1.34" tracing-log = "0.1.3" tracing-subscriber = { version = "0.3.11", features = ["env-filter", "json"] } -[dependencies.sqlx] -git = "https://github.com/launchbadge/sqlx" -rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" -features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] - [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } @@ -76,13 +74,10 @@ env_logger = "0.9" log = { version = "0.4.16", features = ["kv_unstable_serde"] } util = { path = "../util" } lazy_static = "1.4" +sea-query-binder = { version = "0.2", features = ["sqlx-sqlite"] } serde_json = { version = "1.0", features = ["preserve_order"] } +sqlx = { version = "0.6", features = ["sqlite"] } unindent = "0.1" -[dev-dependencies.sqlx] -git = "https://github.com/launchbadge/sqlx" -rev = "4b7053807c705df312bcb9b6281e184bf7534eb3" -features = ["sqlite"] - [features] seed-support = ["clap", "lipsum", "reqwest"] diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index eff97855c6a601992cc8c2a96d953e063b0a6cb5..044d4ef8d7790f48491e0d4797080f78073662ce 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,3 +1,7 @@ +mod schema; +#[cfg(test)] +mod tests; + use crate::{Error, Result}; use anyhow::anyhow; use axum::http::StatusCode; @@ -5,6 +9,8 @@ use collections::{BTreeMap, HashMap, HashSet}; use dashmap::DashMap; use futures::{future::BoxFuture, FutureExt, StreamExt}; use rpc::{proto, ConnectionId}; +use sea_query::{Expr, Query}; +use sea_query_binder::SqlxBinder; use serde::{Deserialize, Serialize}; use sqlx::{ migrate::{Migrate as _, Migration, MigrationSource}, @@ -89,6 +95,23 @@ impl BeginTransaction for Db { } } +pub trait BuildQuery { + fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues); +} + +impl BuildQuery for Db { + fn 
build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues) { + query.build_sqlx(sea_query::PostgresQueryBuilder) + } +} + +#[cfg(test)] +impl BuildQuery for Db { + fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues) { + query.build_sqlx(sea_query::SqliteQueryBuilder) + } +} + pub trait RowsAffected { fn rows_affected(&self) -> u64; } @@ -595,10 +618,11 @@ impl Db { impl Db where - Self: BeginTransaction, + Self: BeginTransaction + BuildQuery, D: sqlx::Database + sqlx::migrate::MigrateDatabase, D::Connection: sqlx::migrate::Migrate, for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, + for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>, for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>, for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>, D::QueryResult: RowsAffected, @@ -1537,63 +1561,66 @@ where worktrees: &[proto::WorktreeMetadata], ) -> Result> { self.transact(|mut tx| async move { - let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( - " - SELECT room_id, user_id - FROM room_participants - WHERE answering_connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; + let (sql, values) = self.build_query( + Query::select() + .columns([ + schema::room_participant::Definition::RoomId, + schema::room_participant::Definition::UserId, + ]) + .from(schema::room_participant::Definition::Table) + .and_where( + Expr::col(schema::room_participant::Definition::AnsweringConnectionId) + .eq(connection_id.0), + ), + ); + let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values) + .fetch_one(&mut tx) + .await?; if room_id != expected_room_id { return Err(anyhow!("shared project on unexpected room"))?; } - let project_id: ProjectId = sqlx::query_scalar( - " - INSERT INTO projects (room_id, host_user_id, host_connection_id) - VALUES ($1, $2, $3) - RETURNING id - ", - ) - .bind(room_id) - .bind(user_id) - 
.bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; + let (sql, values) = self.build_query( + Query::insert() + .into_table(schema::project::Definition::Table) + .columns([ + schema::project::Definition::RoomId, + schema::project::Definition::HostUserId, + schema::project::Definition::HostConnectionId, + ]) + .values_panic([room_id.into(), user_id.into(), connection_id.0.into()]) + .returning_col(schema::project::Definition::Id), + ); + let project_id: ProjectId = sqlx::query_scalar_with(&sql, values) + .fetch_one(&mut tx) + .await?; if !worktrees.is_empty() { - let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); - params.pop(); - let query = format!( - " - INSERT INTO worktrees ( - project_id, - id, - root_name, - abs_path, - visible, - scan_id, - is_complete - ) - VALUES {params} - " - ); - - let mut query = sqlx::query(&query); + let mut query = Query::insert() + .into_table(schema::worktree::Definition::Table) + .columns([ + schema::worktree::Definition::ProjectId, + schema::worktree::Definition::Id, + schema::worktree::Definition::RootName, + schema::worktree::Definition::AbsPath, + schema::worktree::Definition::Visible, + schema::worktree::Definition::ScanId, + schema::worktree::Definition::IsComplete, + ]) + .to_owned(); for worktree in worktrees { - query = query - .bind(project_id) - .bind(worktree.id as i32) - .bind(&worktree.root_name) - .bind(&worktree.abs_path) - .bind(worktree.visible) - .bind(0) - .bind(false); + query.values_panic([ + project_id.into(), + worktree.id.into(), + worktree.root_name.clone().into(), + worktree.abs_path.clone().into(), + worktree.visible.into(), + 0.into(), + false.into(), + ]); } - query.execute(&mut tx).await?; + let (sql, values) = self.build_query(&query); + sqlx::query_with(&sql, values).execute(&mut tx).await?; } sqlx::query( @@ -2648,6 +2675,12 @@ macro_rules! 
id_type { self.0.fmt(f) } } + + impl From<$name> for sea_query::Value { + fn from(value: $name) -> Self { + sea_query::Value::Int(Some(value.0)) + } + } }; } @@ -2692,6 +2725,7 @@ id_type!(WorktreeId); #[derive(Clone, Debug, Default, FromRow, PartialEq)] struct WorktreeRow { pub id: WorktreeId, + pub project_id: ProjectId, pub abs_path: String, pub root_name: String, pub visible: bool, diff --git a/crates/collab/src/db/schema.rs b/crates/collab/src/db/schema.rs new file mode 100644 index 0000000000000000000000000000000000000000..40a3e334d19bf483302beab702ca4038500d0138 --- /dev/null +++ b/crates/collab/src/db/schema.rs @@ -0,0 +1,43 @@ +pub mod project { + use sea_query::Iden; + + #[derive(Iden)] + pub enum Definition { + #[iden = "projects"] + Table, + Id, + RoomId, + HostUserId, + HostConnectionId, + } +} + +pub mod worktree { + use sea_query::Iden; + + #[derive(Iden)] + pub enum Definition { + #[iden = "worktrees"] + Table, + Id, + ProjectId, + AbsPath, + RootName, + Visible, + ScanId, + IsComplete, + } +} + +pub mod room_participant { + use sea_query::Iden; + + #[derive(Iden)] + pub enum Definition { + #[iden = "room_participants"] + Table, + RoomId, + UserId, + AnsweringConnectionId, + } +} diff --git a/crates/collab/src/db_tests.rs b/crates/collab/src/db/tests.rs similarity index 99% rename from crates/collab/src/db_tests.rs rename to crates/collab/src/db/tests.rs index 444e60ddeb0c5e03df39e132189eac9ecca46033..88488b10d26fda779611d698e608abcabc6ca688 100644 --- a/crates/collab/src/db_tests.rs +++ b/crates/collab/src/db/tests.rs @@ -1,4 +1,4 @@ -use super::db::*; +use super::*; use gpui::executor::{Background, Deterministic}; use std::sync::Arc; diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 20fae38c161e01fd325a05cd2868f437ccef5363..019197fc46e90bf83754014b36bc3394055e1e3d 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -4,8 +4,6 @@ mod db; mod env; mod rpc; -#[cfg(test)] -mod db_tests; #[cfg(test)] mod 
integration_tests; From 11a39226e8491a0774c19cd83b84918d2906fa86 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 29 Nov 2022 16:49:04 +0100 Subject: [PATCH 071/109] Start on a new `db2` module that uses SeaORM --- Cargo.lock | 280 ++++++++++++++++ crates/collab/Cargo.toml | 2 + .../20221109000000_test_schema.sql | 2 +- crates/collab/src/db2.rs | 316 ++++++++++++++++++ crates/collab/src/db2/project.rs | 37 ++ crates/collab/src/db2/project_collaborator.rs | 18 + crates/collab/src/db2/room.rs | 31 ++ crates/collab/src/db2/room_participant.rs | 34 ++ crates/collab/src/db2/worktree.rs | 33 ++ crates/collab/src/lib.rs | 12 + crates/collab/src/main.rs | 1 + 11 files changed, 765 insertions(+), 1 deletion(-) create mode 100644 crates/collab/src/db2.rs create mode 100644 crates/collab/src/db2/project.rs create mode 100644 crates/collab/src/db2/project_collaborator.rs create mode 100644 crates/collab/src/db2/room.rs create mode 100644 crates/collab/src/db2/room_participant.rs create mode 100644 crates/collab/src/db2/worktree.rs diff --git a/Cargo.lock b/Cargo.lock index 5083b9131266ee6987ccbf299c0ad0f86f2cd1bd..7b09775f2a46bad44cfcd2d98645bad8640828e1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,12 @@ # It is not intended for manual editing. 
version = 3 +[[package]] +name = "Inflector" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe438c63458706e03479442743baae6c88256498e6431708f6dfc520a26515d3" + [[package]] name = "activity_indicator" version = "0.1.0" @@ -107,6 +113,12 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "aliasable" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" + [[package]] name = "ambient-authority" version = "0.0.1" @@ -547,6 +559,19 @@ dependencies = [ "rustc-demangle", ] +[[package]] +name = "bae" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33b8de67cc41132507eeece2584804efcb15f85ba516e34c944b7667f480397a" +dependencies = [ + "heck 0.3.3", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "base64" version = "0.13.0" @@ -635,6 +660,51 @@ dependencies = [ "once_cell", ] +[[package]] +name = "borsh" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15bf3650200d8bffa99015595e10f1fbd17de07abbc25bb067da79e769939bfa" +dependencies = [ + "borsh-derive", + "hashbrown 0.11.2", +] + +[[package]] +name = "borsh-derive" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6441c552f230375d18e3cc377677914d2ca2b0d36e52129fe15450a2dce46775" +dependencies = [ + "borsh-derive-internal", + "borsh-schema-derive-internal", + "proc-macro-crate", + "proc-macro2", + "syn", +] + +[[package]] +name = "borsh-derive-internal" +version = "0.9.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5449c28a7b352f2d1e592a8a28bf139bc71afb0764a14f3c02500935d8c44065" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "borsh-schema-derive-internal" +version = "0.9.3" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdbd5696d8bfa21d53d9fe39a714a18538bad11492a42d066dbbc395fb1951c0" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "breadcrumbs" version = "0.1.0" @@ -678,6 +748,27 @@ version = "3.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1ad822118d20d2c234f427000d5acc36eabe1e29a348c89b63dd60b13f28e5d" +[[package]] +name = "bytecheck" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11cac2c12b5adc6570dad2ee1b87eff4955dac476fe12d81e5fdd352e52406f" +dependencies = [ + "bytecheck_derive", + "ptr_meta", +] + +[[package]] +name = "bytecheck_derive" +version = "0.6.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13e576ebe98e605500b3c8041bb888e966653577172df6dd97398714eb30b9bf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "bytemuck" version = "1.12.1" @@ -841,6 +932,7 @@ dependencies = [ "js-sys", "num-integer", "num-traits", + "serde", "time 0.1.44", "wasm-bindgen", "winapi 0.3.9", @@ -1065,6 +1157,7 @@ dependencies = [ "reqwest", "rpc", "scrypt", + "sea-orm", "sea-query", "sea-query-binder", "serde", @@ -3843,6 +3936,29 @@ version = "6.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ff7415e9ae3fff1225851df9e0d9e4e5479f947619774677a63572e55e80eff" +[[package]] +name = "ouroboros" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfbb50b356159620db6ac971c6d5c9ab788c9cc38a6f49619fca2a27acb062ca" +dependencies = [ + "aliasable", + "ouroboros_macro", +] + +[[package]] +name = "ouroboros_macro" +version = "0.15.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0d9d1a6191c4f391f87219d1ea42b23f09ee84d64763cd05ee6ea88d9f384d" +dependencies = [ + "Inflector", + "proc-macro-error", + "proc-macro2", + "quote", + "syn", +] + 
[[package]] name = "outline" version = "0.1.0" @@ -4201,6 +4317,15 @@ version = "0.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eb9f9e6e233e5c4a35559a617bf40a4ec447db2e84c20b55a6f83167b7e57872" +[[package]] +name = "proc-macro-crate" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d6ea3c4595b96363c13943497db34af4460fb474a95c43f4446ad341b8c9785" +dependencies = [ + "toml", +] + [[package]] name = "proc-macro-error" version = "1.0.4" @@ -4446,6 +4571,26 @@ dependencies = [ "cc", ] +[[package]] +name = "ptr_meta" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0738ccf7ea06b608c10564b31debd4f5bc5e197fc8bfe088f68ae5ce81e7a4f1" +dependencies = [ + "ptr_meta_derive", +] + +[[package]] +name = "ptr_meta_derive" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "16b845dbfca988fa33db069c0e230574d15a3088f147a87b64c7589eb662c9ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "pulldown-cmark" version = "0.9.2" @@ -4683,6 +4828,15 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "rend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79af64b4b6362ffba04eef3a4e10829718a4896dac19daa741851c86781edf95" +dependencies = [ + "bytecheck", +] + [[package]] name = "reqwest" version = "0.11.12" @@ -4760,6 +4914,31 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "rkyv" +version = "0.7.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cec2b3485b07d96ddfd3134767b8a447b45ea4eb91448d0a35180ec0ffd5ed15" +dependencies = [ + "bytecheck", + "hashbrown 0.12.3", + "ptr_meta", + "rend", + "rkyv_derive", + "seahash", +] + +[[package]] +name = "rkyv_derive" +version = "0.7.39" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"6eaedadc88b53e36dd32d940ed21ae4d850d5916f2581526921f553a72ac34c4" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "rmp" version = "0.8.11" @@ -4911,6 +5090,24 @@ dependencies = [ "walkdir", ] +[[package]] +name = "rust_decimal" +version = "1.27.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33c321ee4e17d2b7abe12b5d20c1231db708dd36185c8a21e9de5fed6da4dbe9" +dependencies = [ + "arrayvec 0.7.2", + "borsh", + "bytecheck", + "byteorder", + "bytes 1.2.1", + "num-traits", + "rand 0.8.5", + "rkyv", + "serde", + "serde_json", +] + [[package]] name = "rustc-demangle" version = "0.1.21" @@ -4982,6 +5179,12 @@ dependencies = [ "base64", ] +[[package]] +name = "rustversion" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97477e48b4cf8603ad5f7aaf897467cf42ab4218a38ef76fb14c2d6773a6d6a8" + [[package]] name = "rustybuzz" version = "0.3.0" @@ -5123,13 +5326,59 @@ dependencies = [ "untrusted", ] +[[package]] +name = "sea-orm" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3120bc435b8640963ffda698f877610e07e077157e216eb99408d819c344034d" +dependencies = [ + "async-stream", + "async-trait", + "chrono", + "futures 0.3.24", + "futures-util", + "log", + "ouroboros", + "rust_decimal", + "sea-orm-macros", + "sea-query", + "sea-query-binder", + "sea-strum", + "serde", + "serde_json", + "sqlx", + "thiserror", + "time 0.3.15", + "tracing", + "url", + "uuid 1.2.1", +] + +[[package]] +name = "sea-orm-macros" +version = "0.10.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c54bacfeb842813c16821e21f9456c358861a448294075184ea1d6307e386d08" +dependencies = [ + "bae", + "heck 0.3.3", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "sea-query" version = "0.27.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"a4f0fc4d8e44e1d51c739a68d336252a18bc59553778075d5e32649be6ec92ed" dependencies = [ + "chrono", + "rust_decimal", "sea-query-derive", + "serde_json", + "time 0.3.15", + "uuid 1.2.1", ] [[package]] @@ -5138,8 +5387,13 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c2585b89c985cfacfe0ec9fc9e7bb055b776c1a2581c4e3c6185af2b8bf8865" dependencies = [ + "chrono", + "rust_decimal", "sea-query", + "serde_json", "sqlx", + "time 0.3.15", + "uuid 1.2.1", ] [[package]] @@ -5155,6 +5409,28 @@ dependencies = [ "thiserror", ] +[[package]] +name = "sea-strum" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "391d06a6007842cfe79ac6f7f53911b76dfd69fc9a6769f1cf6569d12ce20e1b" +dependencies = [ + "sea-strum_macros", +] + +[[package]] +name = "sea-strum_macros" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69b4397b825df6ccf1e98bcdabef3bbcfc47ff5853983467850eeab878384f21" +dependencies = [ + "heck 0.3.3", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "seahash" version = "4.1.0" @@ -5670,6 +5946,7 @@ dependencies = [ "bitflags", "byteorder", "bytes 1.2.1", + "chrono", "crc", "crossbeam-queue", "dirs 4.0.0", @@ -5693,10 +5970,12 @@ dependencies = [ "log", "md-5", "memchr", + "num-bigint", "once_cell", "paste", "percent-encoding", "rand 0.8.5", + "rust_decimal", "rustls 0.20.7", "rustls-pemfile", "serde", @@ -6847,6 +7126,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "feb41e78f93363bb2df8b0e86a2ca30eed7806ea16ea0c790d757cf93f79be83" dependencies = [ "getrandom 0.2.7", + "serde", ] [[package]] diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index e854b003c8bdf5d4257cdf6bc05a7c8641ed256d..e10f9fe8dc29b4f9bbde8840a1264627e2a6a632 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -36,6 +36,7 @@ prometheus = "0.13" rand = "0.8" reqwest = { 
version = "0.11", features = ["json"], optional = true } scrypt = "0.7" +sea-orm = { version = "0.10", features = ["sqlx-postgres", "runtime-tokio-rustls"] } sea-query = { version = "0.27", features = ["derive"] } sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] } serde = { version = "1.0", features = ["derive", "rc"] } @@ -74,6 +75,7 @@ env_logger = "0.9" log = { version = "0.4.16", features = ["kv_unstable_serde"] } util = { path = "../util" } lazy_static = "1.4" +sea-orm = { version = "0.10", features = ["sqlx-sqlite"] } sea-query-binder = { version = "0.2", features = ["sqlx-sqlite"] } serde_json = { version = "1.0", features = ["preserve_order"] } sqlx = { version = "0.6", features = ["sqlite"] } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 02ca0c75a9d40132970cf08d8961d828e4d2f07f..65bf00e74ccfa70cccb1b80bfe7b9142450ce5a1 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -41,7 +41,7 @@ CREATE TABLE "rooms" ( CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY, - "room_id" INTEGER REFERENCES rooms (id), + "room_id" INTEGER REFERENCES rooms (id) NOT NULL, "host_user_id" INTEGER REFERENCES users (id) NOT NULL, "host_connection_id" INTEGER NOT NULL ); diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs new file mode 100644 index 0000000000000000000000000000000000000000..687e93daae78599c599267dd4ebe64e8f95e7cb6 --- /dev/null +++ b/crates/collab/src/db2.rs @@ -0,0 +1,316 @@ +mod project; +mod project_collaborator; +mod room; +mod room_participant; +mod worktree; + +use crate::{Error, Result}; +use anyhow::anyhow; +use collections::HashMap; +use dashmap::DashMap; +use futures::StreamExt; +use rpc::{proto, ConnectionId}; +use sea_orm::ActiveValue; +use sea_orm::{ + entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, 
DbErr, + TransactionTrait, +}; +use serde::{Deserialize, Serialize}; +use std::ops::{Deref, DerefMut}; +use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; +use tokio::sync::{Mutex, OwnedMutexGuard}; + +pub struct Database { + pool: DatabaseConnection, + rooms: DashMap>>, + #[cfg(test)] + background: Option>, + #[cfg(test)] + runtime: Option, +} + +impl Database { + pub async fn new(url: &str, max_connections: u32) -> Result { + let mut options = ConnectOptions::new(url.into()); + options.max_connections(max_connections); + Ok(Self { + pool: sea_orm::Database::connect(options).await?, + rooms: DashMap::with_capacity(16384), + #[cfg(test)] + background: None, + #[cfg(test)] + runtime: None, + }) + } + + pub async fn share_project( + &self, + room_id: RoomId, + connection_id: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result> { + self.transact(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find participant"))?; + if participant.room_id != room_id.0 { + return Err(anyhow!("shared project on unexpected room"))?; + } + + let project = project::ActiveModel { + room_id: ActiveValue::set(participant.room_id), + host_user_id: ActiveValue::set(participant.user_id), + host_connection_id: ActiveValue::set(connection_id.0 as i32), + ..Default::default() + } + .insert(&tx) + .await?; + + worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i32), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + })) + .exec(&tx) + .await?; + + project_collaborator::ActiveModel { + project_id: ActiveValue::set(project.id), + connection_id: ActiveValue::set(connection_id.0 as i32), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(0), + is_host: ActiveValue::set(true), + ..Default::default() + } + .insert(&tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, (ProjectId(project.id), room)) + .await + }) + .await + } + + async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { + let db_room = room::Entity::find_by_id(room_id.0) + .one(tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find room"))?; + + let mut db_participants = db_room + .find_related(room_participant::Entity) + .stream(tx) + .await?; + let mut participants = HashMap::default(); + let mut pending_participants = Vec::new(); + while let Some(db_participant) = db_participants.next().await { + let db_participant = db_participant?; + if let Some(answering_connection_id) = db_participant.answering_connection_id { + let location = match ( + db_participant.location_kind, + db_participant.location_project_id, + ) { + (Some(0), Some(project_id)) => { + Some(proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { + id: project_id as u64, + }, + )) + } + (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( + Default::default(), + )), + _ => Some(proto::participant_location::Variant::External( + Default::default(), + )), + }; + participants.insert( + answering_connection_id, + proto::Participant { + user_id: db_participant.user_id as u64, + peer_id: answering_connection_id as u32, + projects: Default::default(), + location: Some(proto::ParticipantLocation { variant: location }), + }, + ); + } else { + pending_participants.push(proto::PendingParticipant { + user_id: db_participant.user_id as u64, + calling_user_id: db_participant.calling_user_id as u64, + initial_project_id: db_participant.initial_project_id.map(|id| id as u64), + }); + } + } + + let mut db_projects = db_room + .find_related(project::Entity) + .find_with_related(worktree::Entity) + .stream(tx) + .await?; + + while let Some(row) = db_projects.next().await { + let (db_project, db_worktree) = row?; + if let Some(participant) = participants.get_mut(&db_project.host_connection_id) { + let project = if let Some(project) = participant + .projects + .iter_mut() + .find(|project| project.id as i32 == db_project.id) + { + project + } else { + participant.projects.push(proto::ParticipantProject { + id: db_project.id as u64, + 
worktree_root_names: Default::default(), + }); + participant.projects.last_mut().unwrap() + }; + + if let Some(db_worktree) = db_worktree { + project.worktree_root_names.push(db_worktree.root_name); + } + } + } + + Ok(proto::Room { + id: db_room.id as u64, + live_kit_room: db_room.live_kit_room, + participants: participants.into_values().collect(), + pending_participants, + }) + } + + async fn commit_room_transaction( + &self, + room_id: RoomId, + tx: DatabaseTransaction, + data: T, + ) -> Result> { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + tx.commit().await?; + Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }) + } + + async fn transact(&self, f: F) -> Result + where + F: Send + Fn(DatabaseTransaction) -> Fut, + Fut: Send + Future>, + { + let body = async { + loop { + let tx = self.pool.begin().await?; + match f(tx).await { + Ok(result) => return Ok(result), + Err(error) => match error { + Error::Database2( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => + { + // Retry (don't break the loop) + } + error @ _ => return Err(error), + }, + } + } + }; + + #[cfg(test)] + { + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + self.runtime.as_ref().unwrap().block_on(body) + } + + #[cfg(not(test))] + { + body.await + } + } +} + +pub struct RoomGuard { + data: T, + _guard: OwnedMutexGuard<()>, + _not_send: PhantomData>, +} + +impl Deref for RoomGuard { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for RoomGuard { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +macro_rules! 
id_type { + ($name:ident) => { + #[derive( + Clone, + Copy, + Debug, + Default, + PartialEq, + Eq, + PartialOrd, + Ord, + Hash, + sqlx::Type, + Serialize, + Deserialize, + )] + #[sqlx(transparent)] + #[serde(transparent)] + pub struct $name(pub i32); + + impl $name { + #[allow(unused)] + pub const MAX: Self = Self(i32::MAX); + + #[allow(unused)] + pub fn from_proto(value: u64) -> Self { + Self(value as i32) + } + + #[allow(unused)] + pub fn to_proto(self) -> u64 { + self.0 as u64 + } + } + + impl std::fmt::Display for $name { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + self.0.fmt(f) + } + } + }; +} + +id_type!(UserId); +id_type!(RoomId); +id_type!(RoomParticipantId); +id_type!(ProjectId); +id_type!(WorktreeId); diff --git a/crates/collab/src/db2/project.rs b/crates/collab/src/db2/project.rs new file mode 100644 index 0000000000000000000000000000000000000000..4ae061683508bc2a1ab2ba580668bb45775f92c6 --- /dev/null +++ b/crates/collab/src/db2/project.rs @@ -0,0 +1,37 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "projects")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub room_id: i32, + pub host_user_id: i32, + pub host_connection_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, + #[sea_orm(has_many = "super::worktree::Entity")] + Worktree, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Worktree.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db2/project_collaborator.rs b/crates/collab/src/db2/project_collaborator.rs new file mode 100644 index 0000000000000000000000000000000000000000..da567eb2c23e683a1fe7b319978511985e819017 --- 
/dev/null +++ b/crates/collab/src/db2/project_collaborator.rs @@ -0,0 +1,18 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "project_collaborators")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub project_id: i32, + pub connection_id: i32, + pub user_id: i32, + pub replica_id: i32, + pub is_host: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db2/room.rs b/crates/collab/src/db2/room.rs new file mode 100644 index 0000000000000000000000000000000000000000..18f1d234e54733355715fe46b3b5614065afe680 --- /dev/null +++ b/crates/collab/src/db2/room.rs @@ -0,0 +1,31 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "room_participants")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub live_kit_room: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm(has_many = "super::room_participant::Entity")] + RoomParticipant, + #[sea_orm(has_many = "super::project::Entity")] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::RoomParticipant.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db2/room_participant.rs b/crates/collab/src/db2/room_participant.rs new file mode 100644 index 0000000000000000000000000000000000000000..c9b7a13e07f53a8cab2b44bef2927dc280abe1c6 --- /dev/null +++ b/crates/collab/src/db2/room_participant.rs @@ -0,0 +1,34 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "room_participants")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + pub room_id: i32, + pub user_id: i32, 
+ pub answering_connection_id: Option, + pub location_kind: Option, + pub location_project_id: Option, + pub initial_project_id: Option, + pub calling_user_id: i32, + pub calling_connection_id: i32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::room::Entity", + from = "Column::RoomId", + to = "super::room::Column::Id" + )] + Room, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Room.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db2/worktree.rs b/crates/collab/src/db2/worktree.rs new file mode 100644 index 0000000000000000000000000000000000000000..3a630fcfc9d3002206580243129745f3a022fa44 --- /dev/null +++ b/crates/collab/src/db2/worktree.rs @@ -0,0 +1,33 @@ +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktrees")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: i32, + #[sea_orm(primary_key)] + pub project_id: i32, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub scan_id: i64, + pub is_complete: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index be21999a4567f385143bfeaba05101a7cd185ce5..23af3344b55656781ea735d81287213186508c94 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -5,6 +5,7 @@ pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String), Database(sqlx::Error), + Database2(sea_orm::error::DbErr), Internal(anyhow::Error), } @@ -20,6 +21,12 @@ impl From for Error { } } +impl From for Error { + 
fn from(error: sea_orm::error::DbErr) -> Self { + Self::Database2(error) + } +} + impl From for Error { fn from(error: axum::Error) -> Self { Self::Internal(error.into()) @@ -45,6 +52,9 @@ impl IntoResponse for Error { Error::Database(error) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } + Error::Database2(error) => { + (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() + } Error::Internal(error) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } @@ -57,6 +67,7 @@ impl std::fmt::Debug for Error { match self { Error::Http(code, message) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), + Error::Database2(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } @@ -67,6 +78,7 @@ impl std::fmt::Display for Error { match self { Error::Http(code, message) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), + Error::Database2(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 019197fc46e90bf83754014b36bc3394055e1e3d..8a2cdc980fbd80f62aa57a7534ab6d9ae3f61f41 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -1,6 +1,7 @@ mod api; mod auth; mod db; +mod db2; mod env; mod rpc; From b7294887c7c2f02c8730c8b662720d02a590cbb0 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 29 Nov 2022 19:20:11 +0100 Subject: [PATCH 072/109] WIP: move to a non-generic test database struct Co-Authored-By: Mikayla Maki Co-Authored-By: Julia Risley --- crates/collab/Cargo.toml | 2 +- crates/collab/src/db2.rs | 180 +++++++- crates/collab/src/db2/tests.rs | 808 +++++++++++++++++++++++++++++++++ crates/collab/src/db2/user.rs | 21 + 4 files changed, 1009 insertions(+), 2 deletions(-) create mode 100644 crates/collab/src/db2/tests.rs create mode 100644 crates/collab/src/db2/user.rs diff --git a/crates/collab/Cargo.toml 
b/crates/collab/Cargo.toml index e10f9fe8dc29b4f9bbde8840a1264627e2a6a632..a268bdd7b096b9c9ce22aea4ea30b09485b8446b 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -42,7 +42,7 @@ sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] } serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" sha-1 = "0.9" -sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid"] } +sqlx = { version = "0.6", features = ["runtime-tokio-rustls", "postgres", "json", "time", "uuid", "any"] } time = { version = "0.3", features = ["serde", "serde-well-known"] } tokio = { version = "1", features = ["full"] } tokio-tungstenite = "0.17" diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 687e93daae78599c599267dd4ebe64e8f95e7cb6..765fea315df706cd10cb0905497ea7c82b4ea9cb 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -2,6 +2,9 @@ mod project; mod project_collaborator; mod room; mod room_participant; +#[cfg(test)] +mod tests; +mod user; mod worktree; use crate::{Error, Result}; @@ -16,11 +19,18 @@ use sea_orm::{ TransactionTrait, }; use serde::{Deserialize, Serialize}; +use sqlx::migrate::{Migrate, Migration, MigrationSource}; +use sqlx::Connection; use std::ops::{Deref, DerefMut}; +use std::path::Path; +use std::time::Duration; use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; use tokio::sync::{Mutex, OwnedMutexGuard}; +pub use user::Model as User; + pub struct Database { + url: String, pool: DatabaseConnection, rooms: DashMap>>, #[cfg(test)] @@ -32,8 +42,9 @@ pub struct Database { impl Database { pub async fn new(url: &str, max_connections: u32) -> Result { let mut options = ConnectOptions::new(url.into()); - options.max_connections(max_connections); + options.min_connections(1).max_connections(max_connections); Ok(Self { + url: url.into(), pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), 
#[cfg(test)] @@ -43,6 +54,59 @@ impl Database { }) } + pub async fn migrate( + &self, + migrations_path: &Path, + ignore_checksum_mismatch: bool, + ) -> anyhow::Result> { + let migrations = MigrationSource::resolve(migrations_path) + .await + .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; + + let mut connection = sqlx::AnyConnection::connect(&self.url).await?; + + connection.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = connection + .list_applied_migrations() + .await? + .into_iter() + .map(|m| (m.version, m)) + .collect(); + + let mut new_migrations = Vec::new(); + for migration in migrations { + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch + { + Err(anyhow!( + "checksum mismatch for applied migration {}", + migration.description + ))?; + } + } + None => { + let elapsed = connection.apply(&migration).await?; + new_migrations.push((migration, elapsed)); + } + } + } + + Ok(new_migrations) + } + + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { + let ids = ids.iter().map(|id| id.0).collect::>(); + self.transact(|tx| async { + let tx = tx; + Ok(user::Entity::find() + .filter(user::Column::Id.is_in(ids.iter().copied())) + .all(&tx) + .await?) + }) + .await + } + pub async fn share_project( &self, room_id: RoomId, @@ -266,6 +330,29 @@ impl DerefMut for RoomGuard { } } +#[derive(Debug, Serialize, Deserialize)] +pub struct NewUserParams { + pub github_login: String, + pub github_user_id: i32, + pub invite_count: i32, +} + +#[derive(Debug)] +pub struct NewUserResult { + pub user_id: UserId, + pub metrics_id: String, + pub inviting_user_id: Option, + pub signup_device_id: Option, +} + +fn random_invite_code() -> String { + nanoid::nanoid!(16) +} + +fn random_email_confirmation_code() -> String { + nanoid::nanoid!(64) +} + macro_rules! 
id_type { ($name:ident) => { #[derive( @@ -314,3 +401,94 @@ id_type!(RoomId); id_type!(RoomParticipantId); id_type!(ProjectId); id_type!(WorktreeId); + +#[cfg(test)] +pub use test::*; + +#[cfg(test)] +mod test { + use super::*; + use gpui::executor::Background; + use lazy_static::lazy_static; + use parking_lot::Mutex; + use rand::prelude::*; + use sqlx::migrate::MigrateDatabase; + use std::sync::Arc; + + pub struct TestDb { + pub db: Option>, + } + + impl TestDb { + pub fn sqlite(background: Arc) -> Self { + let mut rng = StdRng::from_entropy(); + let url = format!("sqlite://file:zed-test-{}?mode=memory", rng.gen::()); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + let db = Database::new(&url, 5).await.unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); + db.migrate(migrations_path.as_ref(), false).await.unwrap(); + db + }); + + db.background = Some(background); + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + } + } + + pub fn postgres(background: Arc) -> Self { + lazy_static! 
{ + static ref LOCK: Mutex<()> = Mutex::new(()); + } + + let _guard = LOCK.lock(); + let mut rng = StdRng::from_entropy(); + let url = format!( + "postgres://postgres@localhost/zed-test-{}", + rng.gen::() + ); + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let mut db = runtime.block_on(async { + sqlx::Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let db = Database::new(&url, 5).await.unwrap(); + let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); + db.migrate(Path::new(migrations_path), false).await.unwrap(); + db + }); + + db.background = Some(background); + db.runtime = Some(runtime); + + Self { + db: Some(Arc::new(db)), + } + } + + pub fn db(&self) -> &Arc { + self.db.as_ref().unwrap() + } + } + + // TODO: Implement drop + // impl Drop for PostgresTestDb { + // fn drop(&mut self) { + // let db = self.db.take().unwrap(); + // db.teardown(&self.url); + // } + // } +} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..6d8878593829b4903406e7dd5d3163377447b36c --- /dev/null +++ b/crates/collab/src/db2/tests.rs @@ -0,0 +1,808 @@ +use super::*; +use gpui::executor::{Background, Deterministic}; +use std::sync::Arc; + +macro_rules! 
test_both_dbs { + ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { + #[gpui::test] + async fn $postgres_test_name() { + let test_db = TestDb::postgres(Deterministic::new(0).build_background()); + let $db = test_db.db(); + $body + } + + #[gpui::test] + async fn $sqlite_test_name() { + let test_db = TestDb::sqlite(Deterministic::new(0).build_background()); + let $db = test_db.db(); + $body + } + }; +} + +test_both_dbs!( + test_get_users_by_ids_postgres, + test_get_users_by_ids_sqlite, + db, + { + let mut user_ids = Vec::new(); + for i in 1..=4 { + user_ids.push( + db.create_user( + &format!("user{i}@example.com"), + false, + NewUserParams { + github_login: format!("user{i}"), + github_user_id: i, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id, + ); + } + + assert_eq!( + db.get_users_by_ids(user_ids.clone()).await.unwrap(), + vec![ + User { + id: user_ids[0], + github_login: "user1".to_string(), + github_user_id: Some(1), + email_address: Some("user1@example.com".to_string()), + admin: false, + ..Default::default() + }, + User { + id: user_ids[1], + github_login: "user2".to_string(), + github_user_id: Some(2), + email_address: Some("user2@example.com".to_string()), + admin: false, + ..Default::default() + }, + User { + id: user_ids[2], + github_login: "user3".to_string(), + github_user_id: Some(3), + email_address: Some("user3@example.com".to_string()), + admin: false, + ..Default::default() + }, + User { + id: user_ids[3], + github_login: "user4".to_string(), + github_user_id: Some(4), + email_address: Some("user4@example.com".to_string()), + admin: false, + ..Default::default() + } + ] + ); + } +); + +test_both_dbs!( + test_get_user_by_github_account_postgres, + test_get_user_by_github_account_sqlite, + db, + { + let user_id1 = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "login1".into(), + github_user_id: 101, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + let 
user_id2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "login2".into(), + github_user_id: 102, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + + let user = db + .get_user_by_github_account("login1", None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id1); + assert_eq!(&user.github_login, "login1"); + assert_eq!(user.github_user_id, Some(101)); + + assert!(db + .get_user_by_github_account("non-existent-login", None) + .await + .unwrap() + .is_none()); + + let user = db + .get_user_by_github_account("the-new-login2", Some(102)) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id2); + assert_eq!(&user.github_login, "the-new-login2"); + assert_eq!(user.github_user_id, Some(102)); + } +); + +test_both_dbs!( + test_create_access_tokens_postgres, + test_create_access_tokens_sqlite, + db, + { + let user = db + .create_user( + "u1@example.com", + false, + NewUserParams { + github_login: "u1".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); + + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); + + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); + + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } +); + +test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { + let mut user_ids = Vec::new(); + 
for i in 0..3 { + user_ids.push( + db.create_user( + &format!("user{i}@example.com"), + false, + NewUserParams { + github_login: format!("user{i}"), + github_user_id: i, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id, + ); + } + + let user_1 = user_ids[0]; + let user_2 = user_ids[1]; + let user_3 = user_ids[2]; + + // User starts with no contacts + assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); + + // User requests a contact. Both users see the pending request. + db.send_contact_request(user_1, user_2).await.unwrap(); + assert!(!db.has_contact(user_1, user_2).await.unwrap()); + assert!(!db.has_contact(user_2, user_1).await.unwrap()); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Outgoing { user_id: user_2 }], + ); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { + user_id: user_1, + should_notify: true + }] + ); + + // User 2 dismisses the contact request notification without accepting or rejecting. + // We shouldn't notify them again. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap_err(); + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { + user_id: user_1, + should_notify: false + }] + ); + + // User can't accept their own contact request + db.respond_to_contact_request(user_1, user_2, true) + .await + .unwrap_err(); + + // User accepts a contact request. Both users see the contact. 
+ db.respond_to_contact_request(user_2, user_1, true) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: true, + busy: false, + }], + ); + assert!(db.has_contact(user_1, user_2).await.unwrap()); + assert!(db.has_contact(user_2, user_1).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }] + ); + + // Users cannot re-request existing contacts. + db.send_contact_request(user_1, user_2).await.unwrap_err(); + db.send_contact_request(user_2, user_1).await.unwrap_err(); + + // Users can't dismiss notifications of them accepting other users' requests. + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap_err(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: true, + busy: false, + }] + ); + + // Users can dismiss notifications of other users accepting their requests. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: false, + busy: false, + }] + ); + + // Users send each other concurrent contact requests and + // see that they are immediately accepted. + db.send_contact_request(user_1, user_3).await.unwrap(); + db.send_contact_request(user_3, user_1).await.unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[ + Contact::Accepted { + user_id: user_2, + should_notify: false, + busy: false, + }, + Contact::Accepted { + user_id: user_3, + should_notify: false, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }], + ); + + // User declines a contact request. Both users see that it is gone. 
+ db.send_contact_request(user_2, user_3).await.unwrap(); + db.respond_to_contact_request(user_3, user_2, false) + .await + .unwrap(); + assert!(!db.has_contact(user_2, user_3).await.unwrap()); + assert!(!db.has_contact(user_3, user_2).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }], + ); +}); + +test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { + let NewUserResult { + user_id: user1, + metrics_id: metrics_id1, + .. + } = db + .create_user( + "person1@example.com", + false, + NewUserParams { + github_login: "person1".into(), + github_user_id: 101, + invite_count: 5, + }, + ) + .await + .unwrap(); + let NewUserResult { + user_id: user2, + metrics_id: metrics_id2, + .. + } = db + .create_user( + "person2@example.com", + false, + NewUserParams { + github_login: "person2".into(), + github_user_id: 102, + invite_count: 5, + }, + ) + .await + .unwrap(); + + assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); + assert_eq!(metrics_id1.len(), 36); + assert_eq!(metrics_id2.len(), 36); + assert_ne!(metrics_id1, metrics_id2); +}); + +#[test] +fn test_fuzzy_like_string() { + assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); +} + +#[gpui::test] +async fn test_fuzzy_search_users() { + let test_db = PostgresTestDb::new(build_background_executor()); + let db = test_db.db(); + for (i, github_login) in [ + "California", + "colorado", + "oregon", + "washington", + "florida", + "delaware", + "rhode-island", + ] + .into_iter() + .enumerate() + { + db.create_user( + 
&format!("{github_login}@example.com"), + false, + NewUserParams { + github_login: github_login.into(), + github_user_id: i as i32, + invite_count: 0, + }, + ) + .await + .unwrap(); + } + + assert_eq!( + fuzzy_search_user_names(db, "clr").await, + &["colorado", "California"] + ); + assert_eq!( + fuzzy_search_user_names(db, "ro").await, + &["rhode-island", "colorado", "oregon"], + ); + + async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { + db.fuzzy_search_users(query, 10) + .await + .unwrap() + .into_iter() + .map(|user| user.github_login) + .collect::>() + } +} + +#[gpui::test] +async fn test_invite_codes() { + let test_db = PostgresTestDb::new(build_background_executor()); + let db = test_db.db(); + + let NewUserResult { user_id: user1, .. } = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "user1".into(), + github_user_id: 0, + invite_count: 0, + }, + ) + .await + .unwrap(); + + // Initially, user 1 has no invite code + assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); + + // Setting invite count to 0 when no code is assigned does not assign a new code + db.set_invite_count_for_user(user1, 0).await.unwrap(); + assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); + + // User 1 creates an invite code that can be used twice. + db.set_invite_count_for_user(user1, 2).await.unwrap(); + let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 2); + + // User 2 redeems the invite code and becomes a contact of user 1. 
+ let user2_invite = db + .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) + .await + .unwrap(); + let NewUserResult { + user_id: user2, + inviting_user_id, + signup_device_id, + metrics_id, + } = db + .create_user_from_invite( + &user2_invite, + NewUserParams { + github_login: "user2".into(), + github_user_id: 2, + invite_count: 7, + }, + ) + .await + .unwrap() + .unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); + assert_eq!(inviting_user_id, Some(user1)); + assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [Contact::Accepted { + user_id: user2, + should_notify: true, + busy: false, + }] + ); + assert_eq!( + db.get_contacts(user2).await.unwrap(), + [Contact::Accepted { + user_id: user1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_invite_code_for_user(user2).await.unwrap().unwrap().1, + 7 + ); + + // User 3 redeems the invite code and becomes a contact of user 1. + let user3_invite = db + .create_invite_from_code(&invite_code, "user3@example.com", None) + .await + .unwrap(); + let NewUserResult { + user_id: user3, + inviting_user_id, + signup_device_id, + .. 
+ } = db + .create_user_from_invite( + &user3_invite, + NewUserParams { + github_login: "user-3".into(), + github_user_id: 3, + invite_count: 3, + }, + ) + .await + .unwrap() + .unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 0); + assert_eq!(inviting_user_id, Some(user1)); + assert!(signup_device_id.is_none()); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user2, + should_notify: true, + busy: false, + }, + Contact::Accepted { + user_id: user3, + should_notify: true, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user3).await.unwrap(), + [Contact::Accepted { + user_id: user1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_invite_code_for_user(user3).await.unwrap().unwrap().1, + 3 + ); + + // Trying to reedem the code for the third time results in an error. + db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) + .await + .unwrap_err(); + + // Invite count can be updated after the code has been created. + db.set_invite_count_for_user(user1, 2).await.unwrap(); + let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 + assert_eq!(invite_count, 2); + + // User 4 can now redeem the invite code and becomes a contact of user 1. 
+ let user4_invite = db + .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) + .await + .unwrap(); + let user4 = db + .create_user_from_invite( + &user4_invite, + NewUserParams { + github_login: "user-4".into(), + github_user_id: 4, + invite_count: 5, + }, + ) + .await + .unwrap() + .unwrap() + .user_id; + + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user2, + should_notify: true, + busy: false, + }, + Contact::Accepted { + user_id: user3, + should_notify: true, + busy: false, + }, + Contact::Accepted { + user_id: user4, + should_notify: true, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user4).await.unwrap(), + [Contact::Accepted { + user_id: user1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, + 5 + ); + + // An existing user cannot redeem invite codes. 
+ db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) + .await + .unwrap_err(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); +} + +#[gpui::test] +async fn test_signups() { + let test_db = PostgresTestDb::new(build_background_executor()); + let db = test_db.db(); + + // people sign up on the waitlist + for i in 0..8 { + db.create_signup(Signup { + email_address: format!("person-{i}@example.com"), + platform_mac: true, + platform_linux: i % 2 == 0, + platform_windows: i % 4 == 0, + editor_features: vec!["speed".into()], + programming_languages: vec!["rust".into(), "c".into()], + device_id: Some(format!("device_id_{i}")), + }) + .await + .unwrap(); + } + + assert_eq!( + db.get_waitlist_summary().await.unwrap(), + WaitlistSummary { + count: 8, + mac_count: 8, + linux_count: 4, + windows_count: 2, + unknown_count: 0, + } + ); + + // retrieve the next batch of signup emails to send + let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); + let addresses = signups_batch1 + .iter() + .map(|s| &s.email_address) + .collect::>(); + assert_eq!( + addresses, + &[ + "person-0@example.com", + "person-1@example.com", + "person-2@example.com" + ] + ); + assert_ne!( + signups_batch1[0].email_confirmation_code, + signups_batch1[1].email_confirmation_code + ); + + // the waitlist isn't updated until we record that the emails + // were successfully sent. + let signups_batch = db.get_unsent_invites(3).await.unwrap(); + assert_eq!(signups_batch, signups_batch1); + + // once the emails go out, we can retrieve the next batch + // of signups. 
+ db.record_sent_invites(&signups_batch1).await.unwrap(); + let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); + let addresses = signups_batch2 + .iter() + .map(|s| &s.email_address) + .collect::>(); + assert_eq!( + addresses, + &[ + "person-3@example.com", + "person-4@example.com", + "person-5@example.com" + ] + ); + + // the sent invites are excluded from the summary. + assert_eq!( + db.get_waitlist_summary().await.unwrap(), + WaitlistSummary { + count: 5, + mac_count: 5, + linux_count: 2, + windows_count: 1, + unknown_count: 0, + } + ); + + // user completes the signup process by providing their + // github account. + let NewUserResult { + user_id, + inviting_user_id, + signup_device_id, + .. + } = db + .create_user_from_invite( + &Invite { + email_address: signups_batch1[0].email_address.clone(), + email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), + }, + NewUserParams { + github_login: "person-0".into(), + github_user_id: 0, + invite_count: 5, + }, + ) + .await + .unwrap() + .unwrap(); + let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); + assert!(inviting_user_id.is_none()); + assert_eq!(user.github_login, "person-0"); + assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); + assert_eq!(user.invite_count, 5); + assert_eq!(signup_device_id.unwrap(), "device_id_0"); + + // cannot redeem the same signup again. + assert!(db + .create_user_from_invite( + &Invite { + email_address: signups_batch1[0].email_address.clone(), + email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), + }, + NewUserParams { + github_login: "some-other-github_account".into(), + github_user_id: 1, + invite_count: 5, + }, + ) + .await + .unwrap() + .is_none()); + + // cannot redeem a signup with the wrong confirmation code. 
+ db.create_user_from_invite( + &Invite { + email_address: signups_batch1[1].email_address.clone(), + email_confirmation_code: "the-wrong-code".to_string(), + }, + NewUserParams { + github_login: "person-1".into(), + github_user_id: 2, + invite_count: 5, + }, + ) + .await + .unwrap_err(); +} + +fn build_background_executor() -> Arc { + Deterministic::new(0).build_background() +} diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db2/user.rs new file mode 100644 index 0000000000000000000000000000000000000000..de865db6798584d9afd84ad4e0c55eadbc336b52 --- /dev/null +++ b/crates/collab/src/db2/user.rs @@ -0,0 +1,21 @@ +use super::UserId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "users")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: UserId, + pub github_login: String, + pub github_user_id: Option, + pub email_address: Option, + pub admin: bool, + pub invite_code: Option, + pub invite_count: i32, + pub connected_once: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} From d9a892a423362c8f85157c94255e4b552b25a0e2 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 12:06:25 +0100 Subject: [PATCH 073/109] Make some db tests pass against the new sea-orm implementation --- .../20221109000000_test_schema.sql | 2 +- crates/collab/src/db2.rs | 164 +- crates/collab/src/db2/project.rs | 7 +- crates/collab/src/db2/project_collaborator.rs | 7 +- crates/collab/src/db2/room.rs | 3 +- crates/collab/src/db2/room_participant.rs | 13 +- crates/collab/src/db2/tests.rs | 1453 +++++++++-------- crates/collab/src/db2/user.rs | 3 +- crates/collab/src/db2/worktree.rs | 4 +- 9 files changed, 897 insertions(+), 759 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 
65bf00e74ccfa70cccb1b80bfe7b9142450ce5a1..aeb6b7f720100d6ef72bcc5221d31747de372682 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -8,7 +8,7 @@ CREATE TABLE "users" ( "inviter_id" INTEGER REFERENCES users (id), "connected_once" BOOLEAN NOT NULL DEFAULT false, "created_at" TIMESTAMP NOT NULL DEFAULT now, - "metrics_id" VARCHAR(255), + "metrics_id" TEXT, "github_user_id" INTEGER ); CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login"); diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 765fea315df706cd10cb0905497ea7c82b4ea9cb..47ddf8cd22689a5c7715768cca9ad6e479a277e7 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -18,6 +18,7 @@ use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, TransactionTrait, }; +use sea_query::OnConflict; use serde::{Deserialize, Serialize}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; use sqlx::Connection; @@ -42,7 +43,7 @@ pub struct Database { impl Database { pub async fn new(url: &str, max_connections: u32) -> Result { let mut options = ConnectOptions::new(url.into()); - options.min_connections(1).max_connections(max_connections); + options.max_connections(max_connections); Ok(Self { url: url.into(), pool: sea_orm::Database::connect(options).await?, @@ -58,7 +59,7 @@ impl Database { &self, migrations_path: &Path, ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { + ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> { let migrations = MigrationSource::resolve(migrations_path) .await .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; @@ -92,11 +93,45 @@ impl Database { } } - Ok(new_migrations) + Ok((connection, new_migrations)) + } + + pub async fn create_user( + &self, + email_address: &str, + admin: bool, + params: NewUserParams, + ) -> Result { + self.transact(|tx| async { + let 
user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(email_address.into())), + github_login: ActiveValue::set(params.github_login.clone()), + github_user_id: ActiveValue::set(Some(params.github_user_id)), + admin: ActiveValue::set(admin), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_column(user::Column::GithubLogin) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; + + tx.commit().await?; + + Ok(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + signup_device_id: None, + inviting_user_id: None, + }) + }) + .await } pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - let ids = ids.iter().map(|id| id.0).collect::>(); self.transact(|tx| async { let tx = tx; Ok(user::Entity::find() @@ -119,7 +154,7 @@ impl Database { .one(&tx) .await? .ok_or_else(|| anyhow!("could not find participant"))?; - if participant.room_id != room_id.0 { + if participant.room_id != room_id { return Err(anyhow!("shared project on unexpected room"))?; } @@ -156,14 +191,14 @@ impl Database { .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (ProjectId(project.id), room)) + self.commit_room_transaction(room_id, tx, (project.id, room)) .await }) .await } async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { - let db_room = room::Entity::find_by_id(room_id.0) + let db_room = room::Entity::find_by_id(room_id) .one(tx) .await? 
.ok_or_else(|| anyhow!("could not find room"))?; @@ -184,7 +219,7 @@ impl Database { (Some(0), Some(project_id)) => { Some(proto::participant_location::Variant::SharedProject( proto::participant_location::SharedProject { - id: project_id as u64, + id: project_id.to_proto(), }, )) } @@ -198,7 +233,7 @@ impl Database { participants.insert( answering_connection_id, proto::Participant { - user_id: db_participant.user_id as u64, + user_id: db_participant.user_id.to_proto(), peer_id: answering_connection_id as u32, projects: Default::default(), location: Some(proto::ParticipantLocation { variant: location }), @@ -206,9 +241,9 @@ impl Database { ); } else { pending_participants.push(proto::PendingParticipant { - user_id: db_participant.user_id as u64, - calling_user_id: db_participant.calling_user_id as u64, - initial_project_id: db_participant.initial_project_id.map(|id| id as u64), + user_id: db_participant.user_id.to_proto(), + calling_user_id: db_participant.calling_user_id.to_proto(), + initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), }); } } @@ -225,12 +260,12 @@ impl Database { let project = if let Some(project) = participant .projects .iter_mut() - .find(|project| project.id as i32 == db_project.id) + .find(|project| project.id == db_project.id.to_proto()) { project } else { participant.projects.push(proto::ParticipantProject { - id: db_project.id as u64, + id: db_project.id.to_proto(), worktree_root_names: Default::default(), }); participant.projects.last_mut().unwrap() @@ -243,7 +278,7 @@ impl Database { } Ok(proto::Room { - id: db_room.id as u64, + id: db_room.id.to_proto(), live_kit_room: db_room.live_kit_room, participants: participants.into_values().collect(), pending_participants, @@ -393,6 +428,84 @@ macro_rules! 
id_type { self.0.fmt(f) } } + + impl From<$name> for sea_query::Value { + fn from(value: $name) -> Self { + sea_query::Value::Int(Some(value.0)) + } + } + + impl sea_orm::TryGetable for $name { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(Self(i32::try_get(res, pre, col)?)) + } + } + + impl sea_query::ValueType for $name { + fn try_from(v: Value) -> Result { + match v { + Value::TinyInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Int(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::TinyUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Unsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + _ => Err(sea_query::ValueTypeErr), + } + } + + fn type_name() -> String { + stringify!($name).into() + } + + fn array_type() -> sea_query::ArrayType { + sea_query::ArrayType::Int + } + + fn column_type() -> sea_query::ColumnType { + sea_query::ColumnType::Integer(None) + } + } + + impl sea_orm::TryFromU64 for $name { + fn try_from_u64(n: u64) -> Result { + Ok(Self(n.try_into().map_err(|_| { + DbErr::ConvertFromU64(concat!( + "error converting ", + stringify!($name), + " to u64" + )) + })?)) + } + } + + impl sea_query::Nullable for $name { + fn null() -> Value { + Value::Int(None) + } + } }; } @@ -400,6 +513,7 @@ id_type!(UserId); id_type!(RoomId); id_type!(RoomParticipantId); id_type!(ProjectId); +id_type!(ProjectCollaboratorId); 
id_type!(WorktreeId); #[cfg(test)] @@ -412,17 +526,18 @@ mod test { use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; + use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; use std::sync::Arc; pub struct TestDb { pub db: Option>, + pub connection: Option, } impl TestDb { pub fn sqlite(background: Arc) -> Self { - let mut rng = StdRng::from_entropy(); - let url = format!("sqlite://file:zed-test-{}?mode=memory", rng.gen::()); + let url = format!("sqlite::memory:"); let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() @@ -431,8 +546,17 @@ mod test { let mut db = runtime.block_on(async { let db = Database::new(&url, 5).await.unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - db.migrate(migrations_path.as_ref(), false).await.unwrap(); + let sql = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations.sqlite/20221109000000_test_schema.sql" + )); + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + sql.into(), + )) + .await + .unwrap(); db }); @@ -441,6 +565,7 @@ mod test { Self { db: Some(Arc::new(db)), + connection: None, } } @@ -476,6 +601,7 @@ mod test { Self { db: Some(Arc::new(db)), + connection: None, } } diff --git a/crates/collab/src/db2/project.rs b/crates/collab/src/db2/project.rs index 4ae061683508bc2a1ab2ba580668bb45775f92c6..21ee0b27d1350603f2bd5b7118cd853a49fee512 100644 --- a/crates/collab/src/db2/project.rs +++ b/crates/collab/src/db2/project.rs @@ -1,12 +1,13 @@ +use super::{ProjectId, RoomId, UserId}; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "projects")] pub struct Model { #[sea_orm(primary_key)] - pub id: i32, - pub room_id: i32, - pub host_user_id: i32, + pub id: ProjectId, + pub room_id: RoomId, + pub host_user_id: UserId, pub host_connection_id: i32, } diff --git a/crates/collab/src/db2/project_collaborator.rs 
b/crates/collab/src/db2/project_collaborator.rs index da567eb2c23e683a1fe7b319978511985e819017..3e572fe5d4fc94029bfa73c91648bfc44800aead 100644 --- a/crates/collab/src/db2/project_collaborator.rs +++ b/crates/collab/src/db2/project_collaborator.rs @@ -1,13 +1,14 @@ +use super::{ProjectCollaboratorId, ProjectId, UserId}; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "project_collaborators")] pub struct Model { #[sea_orm(primary_key)] - pub id: i32, - pub project_id: i32, + pub id: ProjectCollaboratorId, + pub project_id: ProjectId, pub connection_id: i32, - pub user_id: i32, + pub user_id: UserId, pub replica_id: i32, pub is_host: bool, } diff --git a/crates/collab/src/db2/room.rs b/crates/collab/src/db2/room.rs index 18f1d234e54733355715fe46b3b5614065afe680..b57e612d46e32dced2be353e9d7c5bffe6d200bf 100644 --- a/crates/collab/src/db2/room.rs +++ b/crates/collab/src/db2/room.rs @@ -1,10 +1,11 @@ +use super::RoomId; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "room_participants")] pub struct Model { #[sea_orm(primary_key)] - pub id: i32, + pub id: RoomId, pub live_kit_room: String, } diff --git a/crates/collab/src/db2/room_participant.rs b/crates/collab/src/db2/room_participant.rs index c9b7a13e07f53a8cab2b44bef2927dc280abe1c6..4fabfc3068925ae864c31b7c8c9aa8f5f9898ccc 100644 --- a/crates/collab/src/db2/room_participant.rs +++ b/crates/collab/src/db2/room_participant.rs @@ -1,17 +1,18 @@ +use super::{ProjectId, RoomId, RoomParticipantId, UserId}; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "room_participants")] pub struct Model { #[sea_orm(primary_key)] - pub id: i32, - pub room_id: i32, - pub user_id: i32, + pub id: RoomParticipantId, + pub room_id: RoomId, + pub user_id: UserId, pub answering_connection_id: Option, pub location_kind: Option, - pub 
location_project_id: Option, - pub initial_project_id: Option, - pub calling_user_id: i32, + pub location_project_id: Option, + pub initial_project_id: Option, + pub calling_user_id: UserId, pub calling_connection_id: i32, } diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 6d8878593829b4903406e7dd5d3163377447b36c..a5bac241407f7811860be614eea3b8f1a5cf30f3 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -26,9 +26,10 @@ test_both_dbs!( db, { let mut user_ids = Vec::new(); + let mut user_metric_ids = Vec::new(); for i in 1..=4 { - user_ids.push( - db.create_user( + let user = db + .create_user( &format!("user{i}@example.com"), false, NewUserParams { @@ -38,9 +39,9 @@ test_both_dbs!( }, ) .await - .unwrap() - .user_id, - ); + .unwrap(); + user_ids.push(user.user_id); + user_metric_ids.push(user.metrics_id); } assert_eq!( @@ -52,6 +53,7 @@ test_both_dbs!( github_user_id: Some(1), email_address: Some("user1@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[0].parse().unwrap(), ..Default::default() }, User { @@ -60,6 +62,7 @@ test_both_dbs!( github_user_id: Some(2), email_address: Some("user2@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[1].parse().unwrap(), ..Default::default() }, User { @@ -68,6 +71,7 @@ test_both_dbs!( github_user_id: Some(3), email_address: Some("user3@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[2].parse().unwrap(), ..Default::default() }, User { @@ -76,6 +80,7 @@ test_both_dbs!( github_user_id: Some(4), email_address: Some("user4@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[3].parse().unwrap(), ..Default::default() } ] @@ -83,725 +88,725 @@ test_both_dbs!( } ); -test_both_dbs!( - test_get_user_by_github_account_postgres, - test_get_user_by_github_account_sqlite, - db, - { - let user_id1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: 
"login1".into(), - github_user_id: 101, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_id2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "login2".into(), - github_user_id: 102, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let user = db - .get_user_by_github_account("login1", None) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id1); - assert_eq!(&user.github_login, "login1"); - assert_eq!(user.github_user_id, Some(101)); - - assert!(db - .get_user_by_github_account("non-existent-login", None) - .await - .unwrap() - .is_none()); - - let user = db - .get_user_by_github_account("the-new-login2", Some(102)) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id2); - assert_eq!(&user.github_login, "the-new-login2"); - assert_eq!(user.github_user_id, Some(102)); - } -); - -test_both_dbs!( - test_create_access_tokens_postgres, - test_create_access_tokens_sqlite, - db, - { - let user = db - .create_user( - "u1@example.com", - false, - NewUserParams { - github_login: "u1".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - db.create_access_token_hash(user, "h1", 3).await.unwrap(); - db.create_access_token_hash(user, "h2", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h2".to_string(), "h1".to_string()] - ); - - db.create_access_token_hash(user, "h3", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h3".to_string(), "h2".to_string(), "h1".to_string(),] - ); - - db.create_access_token_hash(user, "h4", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h4".to_string(), "h3".to_string(), "h2".to_string(),] - ); - - db.create_access_token_hash(user, "h5", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h5".to_string(), "h4".to_string(), "h3".to_string()] - ); - } -); - 
-test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { - let mut user_ids = Vec::new(); - for i in 0..3 { - user_ids.push( - db.create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id, - ); - } - - let user_1 = user_ids[0]; - let user_2 = user_ids[1]; - let user_3 = user_ids[2]; - - // User starts with no contacts - assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); - - // User requests a contact. Both users see the pending request. - db.send_contact_request(user_1, user_2).await.unwrap(); - assert!(!db.has_contact(user_1, user_2).await.unwrap()); - assert!(!db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Outgoing { user_id: user_2 }], - ); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] - ); - - // User 2 dismisses the contact request notification without accepting or rejecting. - // We shouldn't notify them again. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap_err(); - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] - ); - - // User can't accept their own contact request - db.respond_to_contact_request(user_1, user_2, true) - .await - .unwrap_err(); - - // User accepts a contact request. Both users see the contact. 
- db.respond_to_contact_request(user_2, user_1, true) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - busy: false, - }], - ); - assert!(db.has_contact(user_1, user_2).await.unwrap()); - assert!(db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }] - ); - - // Users cannot re-request existing contacts. - db.send_contact_request(user_1, user_2).await.unwrap_err(); - db.send_contact_request(user_2, user_1).await.unwrap_err(); - - // Users can't dismiss notifications of them accepting other users' requests. - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap_err(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - busy: false, - }] - ); - - // Users can dismiss notifications of other users accepting their requests. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: false, - busy: false, - }] - ); - - // Users send each other concurrent contact requests and - // see that they are immediately accepted. - db.send_contact_request(user_1, user_3).await.unwrap(); - db.send_contact_request(user_3, user_1).await.unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[ - Contact::Accepted { - user_id: user_2, - should_notify: false, - busy: false, - }, - Contact::Accepted { - user_id: user_3, - should_notify: false, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }], - ); - - // User declines a contact request. Both users see that it is gone. 
- db.send_contact_request(user_2, user_3).await.unwrap(); - db.respond_to_contact_request(user_3, user_2, false) - .await - .unwrap(); - assert!(!db.has_contact(user_2, user_3).await.unwrap()); - assert!(!db.has_contact(user_3, user_2).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }], - ); -}); - -test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { - let NewUserResult { - user_id: user1, - metrics_id: metrics_id1, - .. - } = db - .create_user( - "person1@example.com", - false, - NewUserParams { - github_login: "person1".into(), - github_user_id: 101, - invite_count: 5, - }, - ) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - metrics_id: metrics_id2, - .. - } = db - .create_user( - "person2@example.com", - false, - NewUserParams { - github_login: "person2".into(), - github_user_id: 102, - invite_count: 5, - }, - ) - .await - .unwrap(); - - assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); - assert_eq!(metrics_id1.len(), 36); - assert_eq!(metrics_id2.len(), 36); - assert_ne!(metrics_id1, metrics_id2); -}); - -#[test] -fn test_fuzzy_like_string() { - assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); -} - -#[gpui::test] -async fn test_fuzzy_search_users() { - let test_db = PostgresTestDb::new(build_background_executor()); - let db = test_db.db(); - for (i, github_login) in [ - "California", - "colorado", - "oregon", - "washington", - "florida", - "delaware", - "rhode-island", - ] - .into_iter() - .enumerate() - { - db.create_user( - 
&format!("{github_login}@example.com"), - false, - NewUserParams { - github_login: github_login.into(), - github_user_id: i as i32, - invite_count: 0, - }, - ) - .await - .unwrap(); - } - - assert_eq!( - fuzzy_search_user_names(db, "clr").await, - &["colorado", "California"] - ); - assert_eq!( - fuzzy_search_user_names(db, "ro").await, - &["rhode-island", "colorado", "oregon"], - ); - - async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { - db.fuzzy_search_users(query, 10) - .await - .unwrap() - .into_iter() - .map(|user| user.github_login) - .collect::>() - } -} - -#[gpui::test] -async fn test_invite_codes() { - let test_db = PostgresTestDb::new(build_background_executor()); - let db = test_db.db(); - - let NewUserResult { user_id: user1, .. } = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .unwrap(); - - // Initially, user 1 has no invite code - assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); - - // Setting invite count to 0 when no code is assigned does not assign a new code - db.set_invite_count_for_user(user1, 0).await.unwrap(); - assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); - - // User 1 creates an invite code that can be used twice. - db.set_invite_count_for_user(user1, 2).await.unwrap(); - let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 2); - - // User 2 redeems the invite code and becomes a contact of user 1. 
- let user2_invite = db - .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - inviting_user_id, - signup_device_id, - metrics_id, - } = db - .create_user_from_invite( - &user2_invite, - NewUserParams { - github_login: "user2".into(), - github_user_id: 2, - invite_count: 7, - }, - ) - .await - .unwrap() - .unwrap(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); - assert_eq!(inviting_user_id, Some(user1)); - assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }] - ); - assert_eq!( - db.get_contacts(user2).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user2).await.unwrap().unwrap().1, - 7 - ); - - // User 3 redeems the invite code and becomes a contact of user 1. - let user3_invite = db - .create_invite_from_code(&invite_code, "user3@example.com", None) - .await - .unwrap(); - let NewUserResult { - user_id: user3, - inviting_user_id, - signup_device_id, - .. 
- } = db - .create_user_from_invite( - &user3_invite, - NewUserParams { - github_login: "user-3".into(), - github_user_id: 3, - invite_count: 3, - }, - ) - .await - .unwrap() - .unwrap(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 0); - assert_eq!(inviting_user_id, Some(user1)); - assert!(signup_device_id.is_none()); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [ - Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user3, - should_notify: true, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user3).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user3).await.unwrap().unwrap().1, - 3 - ); - - // Trying to reedem the code for the third time results in an error. - db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) - .await - .unwrap_err(); - - // Invite count can be updated after the code has been created. - db.set_invite_count_for_user(user1, 2).await.unwrap(); - let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 - assert_eq!(invite_count, 2); - - // User 4 can now redeem the invite code and becomes a contact of user 1. 
- let user4_invite = db - .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) - .await - .unwrap(); - let user4 = db - .create_user_from_invite( - &user4_invite, - NewUserParams { - github_login: "user-4".into(), - github_user_id: 4, - invite_count: 5, - }, - ) - .await - .unwrap() - .unwrap() - .user_id; - - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [ - Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user3, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user4, - should_notify: true, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user4).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, - 5 - ); - - // An existing user cannot redeem invite codes. 
- db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) - .await - .unwrap_err(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); -} - -#[gpui::test] -async fn test_signups() { - let test_db = PostgresTestDb::new(build_background_executor()); - let db = test_db.db(); - - // people sign up on the waitlist - for i in 0..8 { - db.create_signup(Signup { - email_address: format!("person-{i}@example.com"), - platform_mac: true, - platform_linux: i % 2 == 0, - platform_windows: i % 4 == 0, - editor_features: vec!["speed".into()], - programming_languages: vec!["rust".into(), "c".into()], - device_id: Some(format!("device_id_{i}")), - }) - .await - .unwrap(); - } - - assert_eq!( - db.get_waitlist_summary().await.unwrap(), - WaitlistSummary { - count: 8, - mac_count: 8, - linux_count: 4, - windows_count: 2, - unknown_count: 0, - } - ); - - // retrieve the next batch of signup emails to send - let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); - let addresses = signups_batch1 - .iter() - .map(|s| &s.email_address) - .collect::>(); - assert_eq!( - addresses, - &[ - "person-0@example.com", - "person-1@example.com", - "person-2@example.com" - ] - ); - assert_ne!( - signups_batch1[0].email_confirmation_code, - signups_batch1[1].email_confirmation_code - ); - - // the waitlist isn't updated until we record that the emails - // were successfully sent. - let signups_batch = db.get_unsent_invites(3).await.unwrap(); - assert_eq!(signups_batch, signups_batch1); - - // once the emails go out, we can retrieve the next batch - // of signups. 
- db.record_sent_invites(&signups_batch1).await.unwrap(); - let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); - let addresses = signups_batch2 - .iter() - .map(|s| &s.email_address) - .collect::>(); - assert_eq!( - addresses, - &[ - "person-3@example.com", - "person-4@example.com", - "person-5@example.com" - ] - ); - - // the sent invites are excluded from the summary. - assert_eq!( - db.get_waitlist_summary().await.unwrap(), - WaitlistSummary { - count: 5, - mac_count: 5, - linux_count: 2, - windows_count: 1, - unknown_count: 0, - } - ); - - // user completes the signup process by providing their - // github account. - let NewUserResult { - user_id, - inviting_user_id, - signup_device_id, - .. - } = db - .create_user_from_invite( - &Invite { - email_address: signups_batch1[0].email_address.clone(), - email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), - }, - NewUserParams { - github_login: "person-0".into(), - github_user_id: 0, - invite_count: 5, - }, - ) - .await - .unwrap() - .unwrap(); - let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); - assert!(inviting_user_id.is_none()); - assert_eq!(user.github_login, "person-0"); - assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); - assert_eq!(user.invite_count, 5); - assert_eq!(signup_device_id.unwrap(), "device_id_0"); - - // cannot redeem the same signup again. - assert!(db - .create_user_from_invite( - &Invite { - email_address: signups_batch1[0].email_address.clone(), - email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), - }, - NewUserParams { - github_login: "some-other-github_account".into(), - github_user_id: 1, - invite_count: 5, - }, - ) - .await - .unwrap() - .is_none()); - - // cannot redeem a signup with the wrong confirmation code. 
- db.create_user_from_invite( - &Invite { - email_address: signups_batch1[1].email_address.clone(), - email_confirmation_code: "the-wrong-code".to_string(), - }, - NewUserParams { - github_login: "person-1".into(), - github_user_id: 2, - invite_count: 5, - }, - ) - .await - .unwrap_err(); -} +// test_both_dbs!( +// test_get_user_by_github_account_postgres, +// test_get_user_by_github_account_sqlite, +// db, +// { +// let user_id1 = db +// .create_user( +// "user1@example.com", +// false, +// NewUserParams { +// github_login: "login1".into(), +// github_user_id: 101, +// invite_count: 0, +// }, +// ) +// .await +// .unwrap() +// .user_id; +// let user_id2 = db +// .create_user( +// "user2@example.com", +// false, +// NewUserParams { +// github_login: "login2".into(), +// github_user_id: 102, +// invite_count: 0, +// }, +// ) +// .await +// .unwrap() +// .user_id; + +// let user = db +// .get_user_by_github_account("login1", None) +// .await +// .unwrap() +// .unwrap(); +// assert_eq!(user.id, user_id1); +// assert_eq!(&user.github_login, "login1"); +// assert_eq!(user.github_user_id, Some(101)); + +// assert!(db +// .get_user_by_github_account("non-existent-login", None) +// .await +// .unwrap() +// .is_none()); + +// let user = db +// .get_user_by_github_account("the-new-login2", Some(102)) +// .await +// .unwrap() +// .unwrap(); +// assert_eq!(user.id, user_id2); +// assert_eq!(&user.github_login, "the-new-login2"); +// assert_eq!(user.github_user_id, Some(102)); +// } +// ); + +// test_both_dbs!( +// test_create_access_tokens_postgres, +// test_create_access_tokens_sqlite, +// db, +// { +// let user = db +// .create_user( +// "u1@example.com", +// false, +// NewUserParams { +// github_login: "u1".into(), +// github_user_id: 1, +// invite_count: 0, +// }, +// ) +// .await +// .unwrap() +// .user_id; + +// db.create_access_token_hash(user, "h1", 3).await.unwrap(); +// db.create_access_token_hash(user, "h2", 3).await.unwrap(); +// assert_eq!( +// 
db.get_access_token_hashes(user).await.unwrap(), +// &["h2".to_string(), "h1".to_string()] +// ); + +// db.create_access_token_hash(user, "h3", 3).await.unwrap(); +// assert_eq!( +// db.get_access_token_hashes(user).await.unwrap(), +// &["h3".to_string(), "h2".to_string(), "h1".to_string(),] +// ); + +// db.create_access_token_hash(user, "h4", 3).await.unwrap(); +// assert_eq!( +// db.get_access_token_hashes(user).await.unwrap(), +// &["h4".to_string(), "h3".to_string(), "h2".to_string(),] +// ); + +// db.create_access_token_hash(user, "h5", 3).await.unwrap(); +// assert_eq!( +// db.get_access_token_hashes(user).await.unwrap(), +// &["h5".to_string(), "h4".to_string(), "h3".to_string()] +// ); +// } +// ); + +// test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { +// let mut user_ids = Vec::new(); +// for i in 0..3 { +// user_ids.push( +// db.create_user( +// &format!("user{i}@example.com"), +// false, +// NewUserParams { +// github_login: format!("user{i}"), +// github_user_id: i, +// invite_count: 0, +// }, +// ) +// .await +// .unwrap() +// .user_id, +// ); +// } + +// let user_1 = user_ids[0]; +// let user_2 = user_ids[1]; +// let user_3 = user_ids[2]; + +// // User starts with no contacts +// assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); + +// // User requests a contact. Both users see the pending request. +// db.send_contact_request(user_1, user_2).await.unwrap(); +// assert!(!db.has_contact(user_1, user_2).await.unwrap()); +// assert!(!db.has_contact(user_2, user_1).await.unwrap()); +// assert_eq!( +// db.get_contacts(user_1).await.unwrap(), +// &[Contact::Outgoing { user_id: user_2 }], +// ); +// assert_eq!( +// db.get_contacts(user_2).await.unwrap(), +// &[Contact::Incoming { +// user_id: user_1, +// should_notify: true +// }] +// ); + +// // User 2 dismisses the contact request notification without accepting or rejecting. +// // We shouldn't notify them again. 
+// db.dismiss_contact_notification(user_1, user_2) +// .await +// .unwrap_err(); +// db.dismiss_contact_notification(user_2, user_1) +// .await +// .unwrap(); +// assert_eq!( +// db.get_contacts(user_2).await.unwrap(), +// &[Contact::Incoming { +// user_id: user_1, +// should_notify: false +// }] +// ); + +// // User can't accept their own contact request +// db.respond_to_contact_request(user_1, user_2, true) +// .await +// .unwrap_err(); + +// // User accepts a contact request. Both users see the contact. +// db.respond_to_contact_request(user_2, user_1, true) +// .await +// .unwrap(); +// assert_eq!( +// db.get_contacts(user_1).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_2, +// should_notify: true, +// busy: false, +// }], +// ); +// assert!(db.has_contact(user_1, user_2).await.unwrap()); +// assert!(db.has_contact(user_2, user_1).await.unwrap()); +// assert_eq!( +// db.get_contacts(user_2).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_1, +// should_notify: false, +// busy: false, +// }] +// ); + +// // Users cannot re-request existing contacts. +// db.send_contact_request(user_1, user_2).await.unwrap_err(); +// db.send_contact_request(user_2, user_1).await.unwrap_err(); + +// // Users can't dismiss notifications of them accepting other users' requests. +// db.dismiss_contact_notification(user_2, user_1) +// .await +// .unwrap_err(); +// assert_eq!( +// db.get_contacts(user_1).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_2, +// should_notify: true, +// busy: false, +// }] +// ); + +// // Users can dismiss notifications of other users accepting their requests. +// db.dismiss_contact_notification(user_1, user_2) +// .await +// .unwrap(); +// assert_eq!( +// db.get_contacts(user_1).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_2, +// should_notify: false, +// busy: false, +// }] +// ); + +// // Users send each other concurrent contact requests and +// // see that they are immediately accepted. 
+// db.send_contact_request(user_1, user_3).await.unwrap(); +// db.send_contact_request(user_3, user_1).await.unwrap(); +// assert_eq!( +// db.get_contacts(user_1).await.unwrap(), +// &[ +// Contact::Accepted { +// user_id: user_2, +// should_notify: false, +// busy: false, +// }, +// Contact::Accepted { +// user_id: user_3, +// should_notify: false, +// busy: false, +// } +// ] +// ); +// assert_eq!( +// db.get_contacts(user_3).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_1, +// should_notify: false, +// busy: false, +// }], +// ); + +// // User declines a contact request. Both users see that it is gone. +// db.send_contact_request(user_2, user_3).await.unwrap(); +// db.respond_to_contact_request(user_3, user_2, false) +// .await +// .unwrap(); +// assert!(!db.has_contact(user_2, user_3).await.unwrap()); +// assert!(!db.has_contact(user_3, user_2).await.unwrap()); +// assert_eq!( +// db.get_contacts(user_2).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_1, +// should_notify: false, +// busy: false, +// }] +// ); +// assert_eq!( +// db.get_contacts(user_3).await.unwrap(), +// &[Contact::Accepted { +// user_id: user_1, +// should_notify: false, +// busy: false, +// }], +// ); +// }); + +// test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { +// let NewUserResult { +// user_id: user1, +// metrics_id: metrics_id1, +// .. +// } = db +// .create_user( +// "person1@example.com", +// false, +// NewUserParams { +// github_login: "person1".into(), +// github_user_id: 101, +// invite_count: 5, +// }, +// ) +// .await +// .unwrap(); +// let NewUserResult { +// user_id: user2, +// metrics_id: metrics_id2, +// .. 
+// } = db +// .create_user( +// "person2@example.com", +// false, +// NewUserParams { +// github_login: "person2".into(), +// github_user_id: 102, +// invite_count: 5, +// }, +// ) +// .await +// .unwrap(); + +// assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); +// assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); +// assert_eq!(metrics_id1.len(), 36); +// assert_eq!(metrics_id2.len(), 36); +// assert_ne!(metrics_id1, metrics_id2); +// }); + +// #[test] +// fn test_fuzzy_like_string() { +// assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); +// assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); +// assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); +// } + +// #[gpui::test] +// async fn test_fuzzy_search_users() { +// let test_db = PostgresTestDb::new(build_background_executor()); +// let db = test_db.db(); +// for (i, github_login) in [ +// "California", +// "colorado", +// "oregon", +// "washington", +// "florida", +// "delaware", +// "rhode-island", +// ] +// .into_iter() +// .enumerate() +// { +// db.create_user( +// &format!("{github_login}@example.com"), +// false, +// NewUserParams { +// github_login: github_login.into(), +// github_user_id: i as i32, +// invite_count: 0, +// }, +// ) +// .await +// .unwrap(); +// } + +// assert_eq!( +// fuzzy_search_user_names(db, "clr").await, +// &["colorado", "California"] +// ); +// assert_eq!( +// fuzzy_search_user_names(db, "ro").await, +// &["rhode-island", "colorado", "oregon"], +// ); + +// async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { +// db.fuzzy_search_users(query, 10) +// .await +// .unwrap() +// .into_iter() +// .map(|user| user.github_login) +// .collect::>() +// } +// } + +// #[gpui::test] +// async fn test_invite_codes() { +// let test_db = PostgresTestDb::new(build_background_executor()); +// let db = test_db.db(); + +// let NewUserResult { user_id: user1, .. 
} = db +// .create_user( +// "user1@example.com", +// false, +// NewUserParams { +// github_login: "user1".into(), +// github_user_id: 0, +// invite_count: 0, +// }, +// ) +// .await +// .unwrap(); + +// // Initially, user 1 has no invite code +// assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); + +// // Setting invite count to 0 when no code is assigned does not assign a new code +// db.set_invite_count_for_user(user1, 0).await.unwrap(); +// assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); + +// // User 1 creates an invite code that can be used twice. +// db.set_invite_count_for_user(user1, 2).await.unwrap(); +// let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); +// assert_eq!(invite_count, 2); + +// // User 2 redeems the invite code and becomes a contact of user 1. +// let user2_invite = db +// .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) +// .await +// .unwrap(); +// let NewUserResult { +// user_id: user2, +// inviting_user_id, +// signup_device_id, +// metrics_id, +// } = db +// .create_user_from_invite( +// &user2_invite, +// NewUserParams { +// github_login: "user2".into(), +// github_user_id: 2, +// invite_count: 7, +// }, +// ) +// .await +// .unwrap() +// .unwrap(); +// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); +// assert_eq!(invite_count, 1); +// assert_eq!(inviting_user_id, Some(user1)); +// assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); +// assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); +// assert_eq!( +// db.get_contacts(user1).await.unwrap(), +// [Contact::Accepted { +// user_id: user2, +// should_notify: true, +// busy: false, +// }] +// ); +// assert_eq!( +// db.get_contacts(user2).await.unwrap(), +// [Contact::Accepted { +// user_id: user1, +// should_notify: false, +// busy: false, +// }] +// ); +// assert_eq!( +// 
db.get_invite_code_for_user(user2).await.unwrap().unwrap().1,
+// 7
+// );

+// // User 3 redeems the invite code and becomes a contact of user 1.
+// let user3_invite = db
+// .create_invite_from_code(&invite_code, "user3@example.com", None)
+// .await
+// .unwrap();
+// let NewUserResult {
+// user_id: user3,
+// inviting_user_id,
+// signup_device_id,
+// ..
+// } = db
+// .create_user_from_invite(
+// &user3_invite,
+// NewUserParams {
+// github_login: "user-3".into(),
+// github_user_id: 3,
+// invite_count: 3,
+// },
+// )
+// .await
+// .unwrap()
+// .unwrap();
+// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
+// assert_eq!(invite_count, 0);
+// assert_eq!(inviting_user_id, Some(user1));
+// assert!(signup_device_id.is_none());
+// assert_eq!(
+// db.get_contacts(user1).await.unwrap(),
+// [
+// Contact::Accepted {
+// user_id: user2,
+// should_notify: true,
+// busy: false,
+// },
+// Contact::Accepted {
+// user_id: user3,
+// should_notify: true,
+// busy: false,
+// }
+// ]
+// );
+// assert_eq!(
+// db.get_contacts(user3).await.unwrap(),
+// [Contact::Accepted {
+// user_id: user1,
+// should_notify: false,
+// busy: false,
+// }]
+// );
+// assert_eq!(
+// db.get_invite_code_for_user(user3).await.unwrap().unwrap().1,
+// 3
+// );

+// // Trying to redeem the code for the third time results in an error.
+// db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id"))
+// .await
+// .unwrap_err();

+// // Invite count can be updated after the code has been created.
+// db.set_invite_count_for_user(user1, 2).await.unwrap();
+// let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap();
+// assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0
+// assert_eq!(invite_count, 2);

+// // User 4 can now redeem the invite code and becomes a contact of user 1.
+// let user4_invite = db +// .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) +// .await +// .unwrap(); +// let user4 = db +// .create_user_from_invite( +// &user4_invite, +// NewUserParams { +// github_login: "user-4".into(), +// github_user_id: 4, +// invite_count: 5, +// }, +// ) +// .await +// .unwrap() +// .unwrap() +// .user_id; + +// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); +// assert_eq!(invite_count, 1); +// assert_eq!( +// db.get_contacts(user1).await.unwrap(), +// [ +// Contact::Accepted { +// user_id: user2, +// should_notify: true, +// busy: false, +// }, +// Contact::Accepted { +// user_id: user3, +// should_notify: true, +// busy: false, +// }, +// Contact::Accepted { +// user_id: user4, +// should_notify: true, +// busy: false, +// } +// ] +// ); +// assert_eq!( +// db.get_contacts(user4).await.unwrap(), +// [Contact::Accepted { +// user_id: user1, +// should_notify: false, +// busy: false, +// }] +// ); +// assert_eq!( +// db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, +// 5 +// ); + +// // An existing user cannot redeem invite codes. 
+// db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) +// .await +// .unwrap_err(); +// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); +// assert_eq!(invite_count, 1); +// } + +// #[gpui::test] +// async fn test_signups() { +// let test_db = PostgresTestDb::new(build_background_executor()); +// let db = test_db.db(); + +// // people sign up on the waitlist +// for i in 0..8 { +// db.create_signup(Signup { +// email_address: format!("person-{i}@example.com"), +// platform_mac: true, +// platform_linux: i % 2 == 0, +// platform_windows: i % 4 == 0, +// editor_features: vec!["speed".into()], +// programming_languages: vec!["rust".into(), "c".into()], +// device_id: Some(format!("device_id_{i}")), +// }) +// .await +// .unwrap(); +// } + +// assert_eq!( +// db.get_waitlist_summary().await.unwrap(), +// WaitlistSummary { +// count: 8, +// mac_count: 8, +// linux_count: 4, +// windows_count: 2, +// unknown_count: 0, +// } +// ); + +// // retrieve the next batch of signup emails to send +// let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); +// let addresses = signups_batch1 +// .iter() +// .map(|s| &s.email_address) +// .collect::>(); +// assert_eq!( +// addresses, +// &[ +// "person-0@example.com", +// "person-1@example.com", +// "person-2@example.com" +// ] +// ); +// assert_ne!( +// signups_batch1[0].email_confirmation_code, +// signups_batch1[1].email_confirmation_code +// ); + +// // the waitlist isn't updated until we record that the emails +// // were successfully sent. +// let signups_batch = db.get_unsent_invites(3).await.unwrap(); +// assert_eq!(signups_batch, signups_batch1); + +// // once the emails go out, we can retrieve the next batch +// // of signups. 
+// db.record_sent_invites(&signups_batch1).await.unwrap(); +// let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); +// let addresses = signups_batch2 +// .iter() +// .map(|s| &s.email_address) +// .collect::>(); +// assert_eq!( +// addresses, +// &[ +// "person-3@example.com", +// "person-4@example.com", +// "person-5@example.com" +// ] +// ); + +// // the sent invites are excluded from the summary. +// assert_eq!( +// db.get_waitlist_summary().await.unwrap(), +// WaitlistSummary { +// count: 5, +// mac_count: 5, +// linux_count: 2, +// windows_count: 1, +// unknown_count: 0, +// } +// ); + +// // user completes the signup process by providing their +// // github account. +// let NewUserResult { +// user_id, +// inviting_user_id, +// signup_device_id, +// .. +// } = db +// .create_user_from_invite( +// &Invite { +// email_address: signups_batch1[0].email_address.clone(), +// email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), +// }, +// NewUserParams { +// github_login: "person-0".into(), +// github_user_id: 0, +// invite_count: 5, +// }, +// ) +// .await +// .unwrap() +// .unwrap(); +// let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); +// assert!(inviting_user_id.is_none()); +// assert_eq!(user.github_login, "person-0"); +// assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); +// assert_eq!(user.invite_count, 5); +// assert_eq!(signup_device_id.unwrap(), "device_id_0"); + +// // cannot redeem the same signup again. +// assert!(db +// .create_user_from_invite( +// &Invite { +// email_address: signups_batch1[0].email_address.clone(), +// email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), +// }, +// NewUserParams { +// github_login: "some-other-github_account".into(), +// github_user_id: 1, +// invite_count: 5, +// }, +// ) +// .await +// .unwrap() +// .is_none()); + +// // cannot redeem a signup with the wrong confirmation code. 
+// db.create_user_from_invite( +// &Invite { +// email_address: signups_batch1[1].email_address.clone(), +// email_confirmation_code: "the-wrong-code".to_string(), +// }, +// NewUserParams { +// github_login: "person-1".into(), +// github_user_id: 2, +// invite_count: 5, +// }, +// ) +// .await +// .unwrap_err(); +// } fn build_background_executor() -> Arc { Deterministic::new(0).build_background() diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db2/user.rs index de865db6798584d9afd84ad4e0c55eadbc336b52..a0e21f98110d3bf83935d2e360be3bacf25de367 100644 --- a/crates/collab/src/db2/user.rs +++ b/crates/collab/src/db2/user.rs @@ -1,7 +1,7 @@ use super::UserId; use sea_orm::entity::prelude::*; -#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "users")] pub struct Model { #[sea_orm(primary_key)] @@ -13,6 +13,7 @@ pub struct Model { pub invite_code: Option, pub invite_count: i32, pub connected_once: bool, + pub metrics_id: Uuid, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db2/worktree.rs b/crates/collab/src/db2/worktree.rs index 3a630fcfc9d3002206580243129745f3a022fa44..3c6f7c0c1d62d274b3c2bc95e150678037117e96 100644 --- a/crates/collab/src/db2/worktree.rs +++ b/crates/collab/src/db2/worktree.rs @@ -1,12 +1,14 @@ use sea_orm::entity::prelude::*; +use super::ProjectId; + #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "worktrees")] pub struct Model { #[sea_orm(primary_key)] pub id: i32, #[sea_orm(primary_key)] - pub project_id: i32, + pub project_id: ProjectId, pub abs_path: String, pub root_name: String, pub visible: bool, From 9e59056e7fdf7886ba31461543b5942089cca3fa Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 14:18:46 +0100 Subject: [PATCH 074/109] Implement `db2::Database::get_user_by_github_account` --- crates/collab/src/db2.rs | 97 
++++++++++++++++++++++------ crates/collab/src/db2/tests.rs | 114 ++++++++++++++++----------------- 2 files changed, 136 insertions(+), 75 deletions(-) diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 47ddf8cd22689a5c7715768cca9ad6e479a277e7..1d50437a9cfa5fb4ac0e62abf97df3ca4d0195e5 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -13,11 +13,11 @@ use collections::HashMap; use dashmap::DashMap; use futures::StreamExt; use rpc::{proto, ConnectionId}; -use sea_orm::ActiveValue; use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, TransactionTrait, }; +use sea_orm::{ActiveValue, IntoActiveModel}; use sea_query::OnConflict; use serde::{Deserialize, Serialize}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; @@ -31,7 +31,7 @@ use tokio::sync::{Mutex, OwnedMutexGuard}; pub use user::Model as User; pub struct Database { - url: String, + options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, #[cfg(test)] @@ -41,11 +41,9 @@ pub struct Database { } impl Database { - pub async fn new(url: &str, max_connections: u32) -> Result { - let mut options = ConnectOptions::new(url.into()); - options.max_connections(max_connections); + pub async fn new(options: ConnectOptions) -> Result { Ok(Self { - url: url.into(), + options: options.clone(), pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), #[cfg(test)] @@ -59,12 +57,12 @@ impl Database { &self, migrations_path: &Path, ignore_checksum_mismatch: bool, - ) -> anyhow::Result<(sqlx::AnyConnection, Vec<(Migration, Duration)>)> { + ) -> anyhow::Result> { let migrations = MigrationSource::resolve(migrations_path) .await .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - let mut connection = sqlx::AnyConnection::connect(&self.url).await?; + let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; connection.ensure_migrations_table().await?; let 
applied_migrations: HashMap<_, _> = connection @@ -93,7 +91,7 @@ impl Database { } } - Ok((connection, new_migrations)) + Ok(new_migrations) } pub async fn create_user( @@ -142,6 +140,43 @@ impl Database { .await } + pub async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result> { + self.transact(|tx| async { + let tx = tx; + if let Some(github_user_id) = github_user_id { + if let Some(user_by_github_user_id) = user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(&tx) + .await? + { + let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); + user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); + Ok(Some(user_by_github_user_id.update(&tx).await?)) + } else if let Some(user_by_github_login) = user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await? + { + let mut user_by_github_login = user_by_github_login.into_active_model(); + user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); + Ok(Some(user_by_github_login.update(&tx).await?)) + } else { + Ok(None) + } + } else { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await?) 
+ } + }) + .await + } + pub async fn share_project( &self, room_id: RoomId, @@ -545,7 +580,9 @@ mod test { .unwrap(); let mut db = runtime.block_on(async { - let db = Database::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let db = Database::new(options).await.unwrap(); let sql = include_str!(concat!( env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite/20221109000000_test_schema.sql" @@ -590,7 +627,11 @@ mod test { sqlx::Postgres::create_database(&url) .await .expect("failed to create test db"); - let db = Database::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let db = Database::new(options).await.unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); db @@ -610,11 +651,31 @@ mod test { } } - // TODO: Implement drop - // impl Drop for PostgresTestDb { - // fn drop(&mut self) { - // let db = self.db.take().unwrap(); - // db.teardown(&self.url); - // } - // } + impl Drop for TestDb { + fn drop(&mut self) { + let db = self.db.take().unwrap(); + if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query.into(), + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } + } + } } diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index a5bac241407f7811860be614eea3b8f1a5cf30f3..60d3fa64b03c0e8579f48f08e5ac8807f756057d 100644 --- a/crates/collab/src/db2/tests.rs +++ 
b/crates/collab/src/db2/tests.rs @@ -88,63 +88,63 @@ test_both_dbs!( } ); -// test_both_dbs!( -// test_get_user_by_github_account_postgres, -// test_get_user_by_github_account_sqlite, -// db, -// { -// let user_id1 = db -// .create_user( -// "user1@example.com", -// false, -// NewUserParams { -// github_login: "login1".into(), -// github_user_id: 101, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id; -// let user_id2 = db -// .create_user( -// "user2@example.com", -// false, -// NewUserParams { -// github_login: "login2".into(), -// github_user_id: 102, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id; - -// let user = db -// .get_user_by_github_account("login1", None) -// .await -// .unwrap() -// .unwrap(); -// assert_eq!(user.id, user_id1); -// assert_eq!(&user.github_login, "login1"); -// assert_eq!(user.github_user_id, Some(101)); - -// assert!(db -// .get_user_by_github_account("non-existent-login", None) -// .await -// .unwrap() -// .is_none()); - -// let user = db -// .get_user_by_github_account("the-new-login2", Some(102)) -// .await -// .unwrap() -// .unwrap(); -// assert_eq!(user.id, user_id2); -// assert_eq!(&user.github_login, "the-new-login2"); -// assert_eq!(user.github_user_id, Some(102)); -// } -// ); +test_both_dbs!( + test_get_user_by_github_account_postgres, + test_get_user_by_github_account_sqlite, + db, + { + let user_id1 = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "login1".into(), + github_user_id: 101, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + let user_id2 = db + .create_user( + "user2@example.com", + false, + NewUserParams { + github_login: "login2".into(), + github_user_id: 102, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; + + let user = db + .get_user_by_github_account("login1", None) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id1); + assert_eq!(&user.github_login, "login1"); + 
assert_eq!(user.github_user_id, Some(101)); + + assert!(db + .get_user_by_github_account("non-existent-login", None) + .await + .unwrap() + .is_none()); + + let user = db + .get_user_by_github_account("the-new-login2", Some(102)) + .await + .unwrap() + .unwrap(); + assert_eq!(user.id, user_id2); + assert_eq!(&user.github_login, "the-new-login2"); + assert_eq!(user.github_user_id, Some(102)); + } +); // test_both_dbs!( // test_create_access_tokens_postgres, From 2e24d128dba01f05055725fb43d2c51d89ce7138 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 14:47:03 +0100 Subject: [PATCH 075/109] Implement access tokens using sea-orm --- crates/collab/src/db2.rs | 73 +++++++++++++++++++++++- crates/collab/src/db2/access_token.rs | 29 ++++++++++ crates/collab/src/db2/tests.rs | 82 +++++++++++++-------------- crates/collab/src/db2/user.rs | 11 +++- 4 files changed, 151 insertions(+), 44 deletions(-) create mode 100644 crates/collab/src/db2/access_token.rs diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 1d50437a9cfa5fb4ac0e62abf97df3ca4d0195e5..e2a03931d81d9a62e84db2bc581d16eea2abb08f 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -1,3 +1,4 @@ +mod access_token; mod project; mod project_collaborator; mod room; @@ -17,8 +18,8 @@ use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, TransactionTrait, }; -use sea_orm::{ActiveValue, IntoActiveModel}; -use sea_query::OnConflict; +use sea_orm::{ActiveValue, ConnectionTrait, IntoActiveModel, QueryOrder, QuerySelect}; +use sea_query::{OnConflict, Query}; use serde::{Deserialize, Serialize}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; use sqlx::Connection; @@ -336,6 +337,63 @@ impl Database { }) } + pub async fn create_access_token_hash( + &self, + user_id: UserId, + access_token_hash: &str, + max_access_token_count: usize, + ) -> Result<()> { + self.transact(|tx| async { + let tx = tx; + + 
access_token::ActiveModel { + user_id: ActiveValue::set(user_id), + hash: ActiveValue::set(access_token_hash.into()), + ..Default::default() + } + .insert(&tx) + .await?; + + access_token::Entity::delete_many() + .filter( + access_token::Column::Id.in_subquery( + Query::select() + .column(access_token::Column::Id) + .from(access_token::Entity) + .and_where(access_token::Column::UserId.eq(user_id)) + .order_by(access_token::Column::Id, sea_orm::Order::Desc) + .limit(10000) + .offset(max_access_token_count as u64) + .to_owned(), + ), + ) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Hash, + } + + self.transact(|tx| async move { + Ok(access_token::Entity::find() + .select_only() + .column(access_token::Column::Hash) + .filter(access_token::Column::UserId.eq(user_id)) + .order_by_desc(access_token::Column::Id) + .into_values::<_, QueryAs>() + .all(&tx) + .await?) + }) + .await + } + async fn transact(&self, f: F) -> Result where F: Send + Fn(DatabaseTransaction) -> Fut, @@ -344,6 +402,16 @@ impl Database { let body = async { loop { let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(sea_orm::Statement::from_string( + sea_orm::DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + match f(tx).await { Ok(result) => return Ok(result), Err(error) => match error { @@ -544,6 +612,7 @@ macro_rules! 
id_type { }; } +id_type!(AccessTokenId); id_type!(UserId); id_type!(RoomId); id_type!(RoomParticipantId); diff --git a/crates/collab/src/db2/access_token.rs b/crates/collab/src/db2/access_token.rs new file mode 100644 index 0000000000000000000000000000000000000000..f5caa4843dd43bff501ac87870e367a960dd25ac --- /dev/null +++ b/crates/collab/src/db2/access_token.rs @@ -0,0 +1,29 @@ +use super::{AccessTokenId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "access_tokens")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: AccessTokenId, + pub user_id: UserId, + pub hash: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 60d3fa64b03c0e8579f48f08e5ac8807f756057d..e26ffee7a8830cd4743b2bcccd1aa0a59bdf2b30 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -146,51 +146,51 @@ test_both_dbs!( } ); -// test_both_dbs!( -// test_create_access_tokens_postgres, -// test_create_access_tokens_sqlite, -// db, -// { -// let user = db -// .create_user( -// "u1@example.com", -// false, -// NewUserParams { -// github_login: "u1".into(), -// github_user_id: 1, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id; +test_both_dbs!( + test_create_access_tokens_postgres, + test_create_access_tokens_sqlite, + db, + { + let user = db + .create_user( + "u1@example.com", + false, + NewUserParams { + github_login: "u1".into(), + github_user_id: 1, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id; -// db.create_access_token_hash(user, "h1", 3).await.unwrap(); -// 
db.create_access_token_hash(user, "h2", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h2".to_string(), "h1".to_string()] -// ); + db.create_access_token_hash(user, "h1", 3).await.unwrap(); + db.create_access_token_hash(user, "h2", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h2".to_string(), "h1".to_string()] + ); -// db.create_access_token_hash(user, "h3", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h3".to_string(), "h2".to_string(), "h1".to_string(),] -// ); + db.create_access_token_hash(user, "h3", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h3".to_string(), "h2".to_string(), "h1".to_string(),] + ); -// db.create_access_token_hash(user, "h4", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h4".to_string(), "h3".to_string(), "h2".to_string(),] -// ); + db.create_access_token_hash(user, "h4", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h4".to_string(), "h3".to_string(), "h2".to_string(),] + ); -// db.create_access_token_hash(user, "h5", 3).await.unwrap(); -// assert_eq!( -// db.get_access_token_hashes(user).await.unwrap(), -// &["h5".to_string(), "h4".to_string(), "h3".to_string()] -// ); -// } -// ); + db.create_access_token_hash(user, "h5", 3).await.unwrap(); + assert_eq!( + db.get_access_token_hashes(user).await.unwrap(), + &["h5".to_string(), "h4".to_string(), "h3".to_string()] + ); + } +); // test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { // let mut user_ids = Vec::new(); diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db2/user.rs index a0e21f98110d3bf83935d2e360be3bacf25de367..5e8a48457167cf3c6abdceb457d8fd4362d81773 100644 --- a/crates/collab/src/db2/user.rs +++ b/crates/collab/src/db2/user.rs @@ -17,6 +17,15 @@ pub struct Model { } 
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} +pub enum Relation { + #[sea_orm(has_many = "super::access_token::Entity")] + AccessToken, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::AccessToken.def() + } +} impl ActiveModelBehavior for ActiveModel {} From 04d553d4d32e3c4dea2c608607ca015b230a535b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 15:06:04 +0100 Subject: [PATCH 076/109] Implement `db2::Database::get_user_metrics_id` --- crates/collab/src/db2.rs | 19 ++++++++ crates/collab/src/db2/tests.rs | 80 +++++++++++++++++----------------- 2 files changed, 59 insertions(+), 40 deletions(-) diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index e2a03931d81d9a62e84db2bc581d16eea2abb08f..5c5157d2aa7f9e8b5ca9239159b5ad9bb115a5fa 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -178,6 +178,25 @@ impl Database { .await } + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + MetricsId, + } + + self.transact(|tx| async move { + let metrics_id: Uuid = user::Entity::find_by_id(id) + .select_only() + .column(user::Column::MetricsId) + .into_values::<_, QueryAs>() + .one(&tx) + .await? + .ok_or_else(|| anyhow!("could not find user"))?; + Ok(metrics_id.to_string()) + }) + .await + } + pub async fn share_project( &self, room_id: RoomId, diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index e26ffee7a8830cd4743b2bcccd1aa0a59bdf2b30..c66e2fa4061d2f402861ef58c138da2d1f5fbf51 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -361,46 +361,46 @@ test_both_dbs!( // ); // }); -// test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { -// let NewUserResult { -// user_id: user1, -// metrics_id: metrics_id1, -// .. 
-// } = db -// .create_user( -// "person1@example.com", -// false, -// NewUserParams { -// github_login: "person1".into(), -// github_user_id: 101, -// invite_count: 5, -// }, -// ) -// .await -// .unwrap(); -// let NewUserResult { -// user_id: user2, -// metrics_id: metrics_id2, -// .. -// } = db -// .create_user( -// "person2@example.com", -// false, -// NewUserParams { -// github_login: "person2".into(), -// github_user_id: 102, -// invite_count: 5, -// }, -// ) -// .await -// .unwrap(); - -// assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); -// assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); -// assert_eq!(metrics_id1.len(), 36); -// assert_eq!(metrics_id2.len(), 36); -// assert_ne!(metrics_id1, metrics_id2); -// }); +test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { + let NewUserResult { + user_id: user1, + metrics_id: metrics_id1, + .. + } = db + .create_user( + "person1@example.com", + false, + NewUserParams { + github_login: "person1".into(), + github_user_id: 101, + invite_count: 5, + }, + ) + .await + .unwrap(); + let NewUserResult { + user_id: user2, + metrics_id: metrics_id2, + .. 
+ } = db + .create_user( + "person2@example.com", + false, + NewUserParams { + github_login: "person2".into(), + github_user_id: 102, + invite_count: 5, + }, + ) + .await + .unwrap(); + + assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); + assert_eq!(metrics_id1.len(), 36); + assert_eq!(metrics_id2.len(), 36); + assert_ne!(metrics_id1, metrics_id2); +}); // #[test] // fn test_fuzzy_like_string() { From d1a44b889edd96fd61e4ba1ca712c80f50d45ee9 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 17:36:25 +0100 Subject: [PATCH 077/109] Implement contacts using sea-orm Co-Authored-By: Nathan Sobo --- crates/collab/src/db2.rs | 298 ++++++++++++++++++- crates/collab/src/db2/contact.rs | 58 ++++ crates/collab/src/db2/room_participant.rs | 12 + crates/collab/src/db2/tests.rs | 332 +++++++++++----------- crates/collab/src/db2/user.rs | 8 + 5 files changed, 540 insertions(+), 168 deletions(-) create mode 100644 crates/collab/src/db2/contact.rs diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 5c5157d2aa7f9e8b5ca9239159b5ad9bb115a5fa..35a45acedf0f91bc6c31ac4fc6993b2896f1ebf7 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -1,4 +1,5 @@ mod access_token; +mod contact; mod project; mod project_collaborator; mod room; @@ -18,8 +19,11 @@ use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, TransactionTrait, }; -use sea_orm::{ActiveValue, ConnectionTrait, IntoActiveModel, QueryOrder, QuerySelect}; -use sea_query::{OnConflict, Query}; +use sea_orm::{ + ActiveValue, ConnectionTrait, FromQueryResult, IntoActiveModel, JoinType, QueryOrder, + QuerySelect, +}; +use sea_query::{Alias, Expr, OnConflict, Query}; use serde::{Deserialize, Serialize}; use sqlx::migrate::{Migrate, Migration, MigrationSource}; use sqlx::Connection; @@ -29,6 +33,7 @@ use std::time::Duration; use 
std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; use tokio::sync::{Mutex, OwnedMutexGuard}; +pub use contact::Contact; pub use user::Model as User; pub struct Database { @@ -95,6 +100,8 @@ impl Database { Ok(new_migrations) } + // users + pub async fn create_user( &self, email_address: &str, @@ -197,6 +204,292 @@ impl Database { .await } + // contacts + + pub async fn get_contacts(&self, user_id: UserId) -> Result> { + #[derive(Debug, FromQueryResult)] + struct ContactWithUserBusyStatuses { + user_id_a: UserId, + user_id_b: UserId, + a_to_b: bool, + accepted: bool, + should_notify: bool, + user_a_busy: bool, + user_b_busy: bool, + } + + self.transact(|tx| async move { + let user_a_participant = Alias::new("user_a_participant"); + let user_b_participant = Alias::new("user_b_participant"); + let mut db_contacts = contact::Entity::find() + .column_as( + Expr::tbl(user_a_participant.clone(), room_participant::Column::Id) + .is_not_null(), + "user_a_busy", + ) + .column_as( + Expr::tbl(user_b_participant.clone(), room_participant::Column::Id) + .is_not_null(), + "user_b_busy", + ) + .filter( + contact::Column::UserIdA + .eq(user_id) + .or(contact::Column::UserIdB.eq(user_id)), + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserARoomParticipant.def(), + user_a_participant, + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserBRoomParticipant.def(), + user_b_participant, + ) + .into_model::() + .stream(&tx) + .await?; + + let mut contacts = Vec::new(); + while let Some(db_contact) = db_contacts.next().await { + let db_contact = db_contact?; + if db_contact.user_id_a == user_id { + if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_b, + should_notify: db_contact.should_notify && db_contact.a_to_b, + busy: db_contact.user_b_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_b, + }) + } else { + contacts.push(Contact::Incoming { + user_id: 
db_contact.user_id_b, + should_notify: db_contact.should_notify, + }); + } + } else if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_a, + should_notify: db_contact.should_notify && !db_contact.a_to_b, + busy: db_contact.user_a_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_a, + should_notify: db_contact.should_notify, + }); + } else { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_a, + }); + } + } + + contacts.sort_unstable_by_key(|contact| contact.user_id()); + + Ok(contacts) + }) + .await + } + + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + self.transact(|tx| async move { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; + + Ok(contact::Entity::find() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::Accepted.eq(true)), + ) + .one(&tx) + .await? 
+ .is_some()) + }) + .await + } + + pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + self.transact(|mut tx| async move { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; + + let rows_affected = contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(id_a), + user_id_b: ActiveValue::set(id_b), + a_to_b: ActiveValue::set(a_to_b), + accepted: ActiveValue::set(false), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB]) + .values([ + (contact::Column::Accepted, true.into()), + (contact::Column::ShouldNotify, false.into()), + ]) + .action_and_where( + contact::Column::Accepted.eq(false).and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::UserIdA.eq(id_b)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::UserIdA.eq(id_a))), + ), + ) + .to_owned(), + ) + .exec_without_returning(&tx) + .await?; + + if rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("contact already requested"))? + } + }) + .await + } + + pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { + self.transact(|mut tx| async move { + // let (id_a, id_b) = if responder_id < requester_id { + // (responder_id, requester_id) + // } else { + // (requester_id, responder_id) + // }; + // let query = " + // DELETE FROM contacts + // WHERE user_id_a = $1 AND user_id_b = $2; + // "; + // let result = sqlx::query(query) + // .bind(id_a.0) + // .bind(id_b.0) + // .execute(&mut tx) + // .await?; + + // if result.rows_affected() == 1 { + // tx.commit().await?; + // Ok(()) + // } else { + // Err(anyhow!("no such contact"))? 
+ // } + todo!() + }) + .await + } + + pub async fn dismiss_contact_notification( + &self, + user_id: UserId, + contact_user_id: UserId, + ) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) + } else { + (contact_user_id, user_id, false) + }; + + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + should_notify: ActiveValue::set(false), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::Accepted.eq(true)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::Accepted.eq(false))), + ), + ) + .exec(&tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such contact request"))? + } else { + tx.commit().await?; + Ok(()) + } + }) + .await + } + + pub async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let rows_affected = if accept { + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + accepted: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)), + ) + .exec(&tx) + .await?; + result.rows_affected + } else { + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)) + .and(contact::Column::Accepted.eq(false)), + ) + .exec(&tx) + .await?; + + result.rows_affected + }; + + if rows_affected == 1 { + tx.commit().await?; + 
Ok(()) + } else { + Err(anyhow!("no such contact request"))? + } + }) + .await + } + + // projects + pub async fn share_project( &self, room_id: RoomId, @@ -632,6 +925,7 @@ macro_rules! id_type { } id_type!(AccessTokenId); +id_type!(ContactId); id_type!(UserId); id_type!(RoomId); id_type!(RoomParticipantId); diff --git a/crates/collab/src/db2/contact.rs b/crates/collab/src/db2/contact.rs new file mode 100644 index 0000000000000000000000000000000000000000..c39d6643b3a4066eb159e8dc87f692d1d5ca3c3c --- /dev/null +++ b/crates/collab/src/db2/contact.rs @@ -0,0 +1,58 @@ +use super::{ContactId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "contacts")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: ContactId, + pub user_id_a: UserId, + pub user_id_b: UserId, + pub a_to_b: bool, + pub should_notify: bool, + pub accepted: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::room_participant::Entity", + from = "Column::UserIdA", + to = "super::room_participant::Column::UserId" + )] + UserARoomParticipant, + #[sea_orm( + belongs_to = "super::room_participant::Entity", + from = "Column::UserIdB", + to = "super::room_participant::Column::UserId" + )] + UserBRoomParticipant, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum Contact { + Accepted { + user_id: UserId, + should_notify: bool, + busy: bool, + }, + Outgoing { + user_id: UserId, + }, + Incoming { + user_id: UserId, + should_notify: bool, + }, +} + +impl Contact { + pub fn user_id(&self) -> UserId { + match self { + Contact::Accepted { user_id, .. } => *user_id, + Contact::Outgoing { user_id } => *user_id, + Contact::Incoming { user_id, .. 
} => *user_id, + } + } +} diff --git a/crates/collab/src/db2/room_participant.rs b/crates/collab/src/db2/room_participant.rs index 4fabfc3068925ae864c31b7c8c9aa8f5f9898ccc..c7c804581b07be6825bbc27b44227d8da4a6b26a 100644 --- a/crates/collab/src/db2/room_participant.rs +++ b/crates/collab/src/db2/room_participant.rs @@ -18,6 +18,12 @@ pub struct Model { #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + User, #[sea_orm( belongs_to = "super::room::Entity", from = "Column::RoomId", @@ -26,6 +32,12 @@ pub enum Relation { Room, } +impl Related for Entity { + fn to() -> RelationDef { + Relation::User.def() + } +} + impl Related for Entity { fn to() -> RelationDef { Relation::Room.def() diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index c66e2fa4061d2f402861ef58c138da2d1f5fbf51..1aeb80202500aacbd9a7f087585107303867fe4d 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -192,174 +192,174 @@ test_both_dbs!( } ); -// test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { -// let mut user_ids = Vec::new(); -// for i in 0..3 { -// user_ids.push( -// db.create_user( -// &format!("user{i}@example.com"), -// false, -// NewUserParams { -// github_login: format!("user{i}"), -// github_user_id: i, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap() -// .user_id, -// ); -// } - -// let user_1 = user_ids[0]; -// let user_2 = user_ids[1]; -// let user_3 = user_ids[2]; - -// // User starts with no contacts -// assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); - -// // User requests a contact. Both users see the pending request. 
-// db.send_contact_request(user_1, user_2).await.unwrap(); -// assert!(!db.has_contact(user_1, user_2).await.unwrap()); -// assert!(!db.has_contact(user_2, user_1).await.unwrap()); -// assert_eq!( -// db.get_contacts(user_1).await.unwrap(), -// &[Contact::Outgoing { user_id: user_2 }], -// ); -// assert_eq!( -// db.get_contacts(user_2).await.unwrap(), -// &[Contact::Incoming { -// user_id: user_1, -// should_notify: true -// }] -// ); - -// // User 2 dismisses the contact request notification without accepting or rejecting. -// // We shouldn't notify them again. -// db.dismiss_contact_notification(user_1, user_2) -// .await -// .unwrap_err(); -// db.dismiss_contact_notification(user_2, user_1) -// .await -// .unwrap(); -// assert_eq!( -// db.get_contacts(user_2).await.unwrap(), -// &[Contact::Incoming { -// user_id: user_1, -// should_notify: false -// }] -// ); - -// // User can't accept their own contact request -// db.respond_to_contact_request(user_1, user_2, true) -// .await -// .unwrap_err(); - -// // User accepts a contact request. Both users see the contact. -// db.respond_to_contact_request(user_2, user_1, true) -// .await -// .unwrap(); -// assert_eq!( -// db.get_contacts(user_1).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_2, -// should_notify: true, -// busy: false, -// }], -// ); -// assert!(db.has_contact(user_1, user_2).await.unwrap()); -// assert!(db.has_contact(user_2, user_1).await.unwrap()); -// assert_eq!( -// db.get_contacts(user_2).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_1, -// should_notify: false, -// busy: false, -// }] -// ); - -// // Users cannot re-request existing contacts. -// db.send_contact_request(user_1, user_2).await.unwrap_err(); -// db.send_contact_request(user_2, user_1).await.unwrap_err(); - -// // Users can't dismiss notifications of them accepting other users' requests. 
-// db.dismiss_contact_notification(user_2, user_1) -// .await -// .unwrap_err(); -// assert_eq!( -// db.get_contacts(user_1).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_2, -// should_notify: true, -// busy: false, -// }] -// ); - -// // Users can dismiss notifications of other users accepting their requests. -// db.dismiss_contact_notification(user_1, user_2) -// .await -// .unwrap(); -// assert_eq!( -// db.get_contacts(user_1).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_2, -// should_notify: false, -// busy: false, -// }] -// ); +test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { + let mut user_ids = Vec::new(); + for i in 0..3 { + user_ids.push( + db.create_user( + &format!("user{i}@example.com"), + false, + NewUserParams { + github_login: format!("user{i}"), + github_user_id: i, + invite_count: 0, + }, + ) + .await + .unwrap() + .user_id, + ); + } -// // Users send each other concurrent contact requests and -// // see that they are immediately accepted. -// db.send_contact_request(user_1, user_3).await.unwrap(); -// db.send_contact_request(user_3, user_1).await.unwrap(); -// assert_eq!( -// db.get_contacts(user_1).await.unwrap(), -// &[ -// Contact::Accepted { -// user_id: user_2, -// should_notify: false, -// busy: false, -// }, -// Contact::Accepted { -// user_id: user_3, -// should_notify: false, -// busy: false, -// } -// ] -// ); -// assert_eq!( -// db.get_contacts(user_3).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_1, -// should_notify: false, -// busy: false, -// }], -// ); + let user_1 = user_ids[0]; + let user_2 = user_ids[1]; + let user_3 = user_ids[2]; + + // User starts with no contacts + assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); + + // User requests a contact. Both users see the pending request. 
+ db.send_contact_request(user_1, user_2).await.unwrap(); + assert!(!db.has_contact(user_1, user_2).await.unwrap()); + assert!(!db.has_contact(user_2, user_1).await.unwrap()); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Outgoing { user_id: user_2 }], + ); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { + user_id: user_1, + should_notify: true + }] + ); + + // User 2 dismisses the contact request notification without accepting or rejecting. + // We shouldn't notify them again. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap_err(); + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Incoming { + user_id: user_1, + should_notify: false + }] + ); + + // User can't accept their own contact request + db.respond_to_contact_request(user_1, user_2, true) + .await + .unwrap_err(); -// // User declines a contact request. Both users see that it is gone. -// db.send_contact_request(user_2, user_3).await.unwrap(); -// db.respond_to_contact_request(user_3, user_2, false) -// .await -// .unwrap(); -// assert!(!db.has_contact(user_2, user_3).await.unwrap()); -// assert!(!db.has_contact(user_3, user_2).await.unwrap()); -// assert_eq!( -// db.get_contacts(user_2).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_1, -// should_notify: false, -// busy: false, -// }] -// ); -// assert_eq!( -// db.get_contacts(user_3).await.unwrap(), -// &[Contact::Accepted { -// user_id: user_1, -// should_notify: false, -// busy: false, -// }], -// ); -// }); + // User accepts a contact request. Both users see the contact. 
+ db.respond_to_contact_request(user_2, user_1, true) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: true, + busy: false, + }], + ); + assert!(db.has_contact(user_1, user_2).await.unwrap()); + assert!(db.has_contact(user_2, user_1).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }] + ); + + // Users cannot re-request existing contacts. + db.send_contact_request(user_1, user_2).await.unwrap_err(); + db.send_contact_request(user_2, user_1).await.unwrap_err(); + + // Users can't dismiss notifications of them accepting other users' requests. + db.dismiss_contact_notification(user_2, user_1) + .await + .unwrap_err(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: true, + busy: false, + }] + ); + + // Users can dismiss notifications of other users accepting their requests. + db.dismiss_contact_notification(user_1, user_2) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[Contact::Accepted { + user_id: user_2, + should_notify: false, + busy: false, + }] + ); + + // Users send each other concurrent contact requests and + // see that they are immediately accepted. + db.send_contact_request(user_1, user_3).await.unwrap(); + db.send_contact_request(user_3, user_1).await.unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + &[ + Contact::Accepted { + user_id: user_2, + should_notify: false, + busy: false, + }, + Contact::Accepted { + user_id: user_3, + should_notify: false, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }], + ); + + // User declines a contact request. Both users see that it is gone. 
+ db.send_contact_request(user_2, user_3).await.unwrap(); + db.respond_to_contact_request(user_3, user_2, false) + .await + .unwrap(); + assert!(!db.has_contact(user_2, user_3).await.unwrap()); + assert!(!db.has_contact(user_3, user_2).await.unwrap()); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + &[Contact::Accepted { + user_id: user_1, + should_notify: false, + busy: false, + }], + ); +}); test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { let NewUserResult { diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db2/user.rs index 5e8a48457167cf3c6abdceb457d8fd4362d81773..f6bac9dc77d8dd92ce9353019a610a76a83528ae 100644 --- a/crates/collab/src/db2/user.rs +++ b/crates/collab/src/db2/user.rs @@ -20,6 +20,8 @@ pub struct Model { pub enum Relation { #[sea_orm(has_many = "super::access_token::Entity")] AccessToken, + #[sea_orm(has_one = "super::room_participant::Entity")] + RoomParticipant, } impl Related for Entity { @@ -28,4 +30,10 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::RoomParticipant.def() + } +} + impl ActiveModelBehavior for ActiveModel {} From 4c04d512dbedd1abaa1e60cc1e4d86f2ed6fb87b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 30 Nov 2022 17:39:17 +0100 Subject: [PATCH 078/109] Implement `db2::Database::remove_contact` --- crates/collab/src/db2.rs | 57 ++++++++++++++++++++-------------- crates/collab/src/db2/tests.rs | 12 +++---- 2 files changed, 40 insertions(+), 29 deletions(-) diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 35a45acedf0f91bc6c31ac4fc6993b2896f1ebf7..2e6b349497770939c135dd6b9b0808cb3570c543 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -366,29 +366,28 @@ impl Database { } pub async fn remove_contact(&self, requester_id: UserId, 
responder_id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - // let (id_a, id_b) = if responder_id < requester_id { - // (responder_id, requester_id) - // } else { - // (requester_id, responder_id) - // }; - // let query = " - // DELETE FROM contacts - // WHERE user_id_a = $1 AND user_id_b = $2; - // "; - // let result = sqlx::query(query) - // .bind(id_a.0) - // .bind(id_b.0) - // .execute(&mut tx) - // .await?; - - // if result.rows_affected() == 1 { - // tx.commit().await?; - // Ok(()) - // } else { - // Err(anyhow!("no such contact"))? - // } - todo!() + self.transact(|tx| async move { + let (id_a, id_b) = if responder_id < requester_id { + (responder_id, requester_id) + } else { + (requester_id, responder_id) + }; + + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)), + ) + .exec(&tx) + .await?; + + if result.rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("no such contact"))? 
+ } }) .await } @@ -488,6 +487,18 @@ impl Database { .await } + pub fn fuzzy_like_string(string: &str) -> String { + let mut result = String::with_capacity(string.len() * 2 + 1); + for c in string.chars() { + if c.is_alphanumeric() { + result.push('%'); + result.push(c); + } + } + result.push('%'); + result + } + // projects pub async fn share_project( diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 1aeb80202500aacbd9a7f087585107303867fe4d..45715a925e44137ef2444c50ba8dcc7c43f23763 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -402,12 +402,12 @@ test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { assert_ne!(metrics_id1, metrics_id2); }); -// #[test] -// fn test_fuzzy_like_string() { -// assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); -// assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); -// assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); -// } +#[test] +fn test_fuzzy_like_string() { + assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); +} // #[gpui::test] // async fn test_fuzzy_search_users() { From 2375741bdf0c289ddcbd8b906344db03efa93937 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 10:09:53 +0100 Subject: [PATCH 079/109] Implement `db2::Database::fuzzy_search_users` --- crates/collab/src/db2.rs | 36 +++++++++++--- crates/collab/src/db2/tests.rs | 90 +++++++++++++++++----------------- 2 files changed, 75 insertions(+), 51 deletions(-) diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 2e6b349497770939c135dd6b9b0808cb3570c543..b69f7f32a4c3cb2b10bf5e29ef767002cde6860f 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -20,8 +20,8 @@ use sea_orm::{ TransactionTrait, }; use sea_orm::{ - ActiveValue, ConnectionTrait, FromQueryResult, IntoActiveModel, JoinType, 
QueryOrder, - QuerySelect, + ActiveValue, ConnectionTrait, DatabaseBackend, FromQueryResult, IntoActiveModel, JoinType, + QueryOrder, QuerySelect, Statement, }; use sea_query::{Alias, Expr, OnConflict, Query}; use serde::{Deserialize, Serialize}; @@ -499,6 +499,30 @@ impl Database { result } + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { + self.transact(|tx| async { + let tx = tx; + let like_string = Self::fuzzy_like_string(name_query); + let query = " + SELECT users.* + FROM users + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 + "; + + Ok(user::Entity::find() + .from_raw_sql(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query.into(), + vec![like_string.into(), name_query.into(), limit.into()], + )) + .all(&tx) + .await?) + }) + .await + } + // projects pub async fn share_project( @@ -727,9 +751,9 @@ impl Database { let tx = self.pool.begin().await?; // In Postgres, serializable transactions are opt-in - if let sea_orm::DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(sea_orm::Statement::from_string( - sea_orm::DatabaseBackend::Postgres, + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), )) .await?; @@ -1047,7 +1071,7 @@ mod test { impl Drop for TestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); - if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { + if let DatabaseBackend::Postgres = db.pool.get_database_backend() { db.runtime.as_ref().unwrap().block_on(async { use util::ResultExt; let query = " diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 45715a925e44137ef2444c50ba8dcc7c43f23763..527f70adb8ce42bab6e88f81252812604713a7a6 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -409,53 +409,53 @@ fn 
test_fuzzy_like_string() { assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); } -// #[gpui::test] -// async fn test_fuzzy_search_users() { -// let test_db = PostgresTestDb::new(build_background_executor()); -// let db = test_db.db(); -// for (i, github_login) in [ -// "California", -// "colorado", -// "oregon", -// "washington", -// "florida", -// "delaware", -// "rhode-island", -// ] -// .into_iter() -// .enumerate() -// { -// db.create_user( -// &format!("{github_login}@example.com"), -// false, -// NewUserParams { -// github_login: github_login.into(), -// github_user_id: i as i32, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap(); -// } +#[gpui::test] +async fn test_fuzzy_search_users() { + let test_db = TestDb::postgres(build_background_executor()); + let db = test_db.db(); + for (i, github_login) in [ + "California", + "colorado", + "oregon", + "washington", + "florida", + "delaware", + "rhode-island", + ] + .into_iter() + .enumerate() + { + db.create_user( + &format!("{github_login}@example.com"), + false, + NewUserParams { + github_login: github_login.into(), + github_user_id: i as i32, + invite_count: 0, + }, + ) + .await + .unwrap(); + } -// assert_eq!( -// fuzzy_search_user_names(db, "clr").await, -// &["colorado", "California"] -// ); -// assert_eq!( -// fuzzy_search_user_names(db, "ro").await, -// &["rhode-island", "colorado", "oregon"], -// ); + assert_eq!( + fuzzy_search_user_names(db, "clr").await, + &["colorado", "California"] + ); + assert_eq!( + fuzzy_search_user_names(db, "ro").await, + &["rhode-island", "colorado", "oregon"], + ); -// async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { -// db.fuzzy_search_users(query, 10) -// .await -// .unwrap() -// .into_iter() -// .map(|user| user.github_login) -// .collect::>() -// } -// } + async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { + db.fuzzy_search_users(query, 10) + .await + .unwrap() + .into_iter() + .map(|user| user.github_login) + .collect::>() + } 
+} // #[gpui::test] // async fn test_invite_codes() { From 4f864a20a7cfede662091f3f71c8ba2aba71d295 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 11:10:51 +0100 Subject: [PATCH 080/109] Implement invite codes using sea-orm --- crates/collab/src/db2.rs | 220 ++++++++++++++++++ crates/collab/src/db2/signup.rs | 33 +++ crates/collab/src/db2/tests.rs | 386 ++++++++++++++++---------------- 3 files changed, 446 insertions(+), 193 deletions(-) create mode 100644 crates/collab/src/db2/signup.rs diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index b69f7f32a4c3cb2b10bf5e29ef767002cde6860f..75329f926894d8df21fcd52888822b68638a2f56 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -4,6 +4,7 @@ mod project; mod project_collaborator; mod room; mod room_participant; +mod signup; #[cfg(test)] mod tests; mod user; @@ -14,6 +15,7 @@ use anyhow::anyhow; use collections::HashMap; use dashmap::DashMap; use futures::StreamExt; +use hyper::StatusCode; use rpc::{proto, ConnectionId}; use sea_orm::{ entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, @@ -34,6 +36,7 @@ use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; use tokio::sync::{Mutex, OwnedMutexGuard}; pub use contact::Contact; +pub use signup::Invite; pub use user::Model as User; pub struct Database { @@ -523,6 +526,222 @@ impl Database { .await } + // invite codes + + pub async fn create_invite_from_code( + &self, + code: &str, + email_address: &str, + device_id: Option<&str>, + ) -> Result { + self.transact(|tx| async move { + let existing_user = user::Entity::find() + .filter(user::Column::EmailAddress.eq(email_address)) + .one(&tx) + .await?; + + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } + + let inviter = match user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&tx) + .await? 
+ { + Some(inviter) => inviter, + None => { + return Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))? + } + }; + + if inviter.invite_count == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + let signup = signup::Entity::insert(signup::ActiveModel { + email_address: ActiveValue::set(email_address.into()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + inviting_user_id: ActiveValue::set(Some(inviter.id)), + platform_linux: ActiveValue::set(false), + platform_mac: ActiveValue::set(false), + platform_windows: ActiveValue::set(false), + platform_unknown: ActiveValue::set(true), + device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())), + ..Default::default() + }) + .on_conflict( + OnConflict::column(signup::Column::EmailAddress) + .update_column(signup::Column::InvitingUserId) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; + tx.commit().await?; + + Ok(Invite { + email_address: signup.email_address, + email_confirmation_code: signup.email_confirmation_code, + }) + }) + .await + } + + pub async fn create_user_from_invite( + &self, + invite: &Invite, + user: NewUserParams, + ) -> Result> { + self.transact(|tx| async { + let tx = tx; + let signup = signup::Entity::find() + .filter( + signup::Column::EmailAddress + .eq(invite.email_address.as_str()) + .and( + signup::Column::EmailConfirmationCode + .eq(invite.email_confirmation_code.as_str()), + ), + ) + .one(&tx) + .await? 
+ .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; + + if signup.user_id.is_some() { + return Ok(None); + } + + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(invite.email_address.clone())), + github_login: ActiveValue::set(user.github_login.clone()), + github_user_id: ActiveValue::set(Some(user.github_user_id)), + admin: ActiveValue::set(false), + invite_count: ActiveValue::set(user.invite_count), + invite_code: ActiveValue::set(Some(random_invite_code())), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_columns([ + user::Column::EmailAddress, + user::Column::GithubUserId, + user::Column::Admin, + ]) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; + + let mut signup = signup.into_active_model(); + signup.user_id = ActiveValue::set(Some(user.id)); + let signup = signup.update(&tx).await?; + + if let Some(inviting_user_id) = signup.inviting_user_id { + let result = user::Entity::update_many() + .filter( + user::Column::Id + .eq(inviting_user_id) + .and(user::Column::InviteCount.gt(0)), + ) + .col_expr( + user::Column::InviteCount, + Expr::col(user::Column::InviteCount).sub(1), + ) + .exec(&tx) + .await?; + + if result.rows_affected == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; + } + + contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(inviting_user_id), + user_id_b: ActiveValue::set(user.id), + a_to_b: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + accepted: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&tx) + .await?; + } + + tx.commit().await?; + Ok(Some(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + inviting_user_id: signup.inviting_user_id, + 
signup_device_id: signup.device_id, + })) + }) + .await + } + + pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { + self.transact(|tx| async move { + if count > 0 { + user::Entity::update_many() + .filter( + user::Column::Id + .eq(id) + .and(user::Column::InviteCode.is_null()), + ) + .col_expr(user::Column::InviteCode, random_invite_code().into()) + .exec(&tx) + .await?; + } + + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .col_expr(user::Column::InviteCount, count.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { + self.transact(|tx| async move { + match user::Entity::find_by_id(id).one(&tx).await? { + Some(user) if user.invite_code.is_some() => { + Ok(Some((user.invite_code.unwrap(), user.invite_count as u32))) + } + _ => Ok(None), + } + }) + .await + } + + pub async fn get_user_for_invite_code(&self, code: &str) -> Result { + self.transact(|tx| async move { + user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&tx) + .await? 
+ .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) + }) + .await + } + // projects pub async fn share_project( @@ -966,6 +1185,7 @@ id_type!(RoomId); id_type!(RoomParticipantId); id_type!(ProjectId); id_type!(ProjectCollaboratorId); +id_type!(SignupId); id_type!(WorktreeId); #[cfg(test)] diff --git a/crates/collab/src/db2/signup.rs b/crates/collab/src/db2/signup.rs new file mode 100644 index 0000000000000000000000000000000000000000..ad0aa5eb824b64bafc491a7b4125f333096f8210 --- /dev/null +++ b/crates/collab/src/db2/signup.rs @@ -0,0 +1,33 @@ +use super::{SignupId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "signups")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: SignupId, + pub email_address: String, + pub email_confirmation_code: String, + pub email_confirmation_sent: bool, + pub created_at: DateTime, + pub device_id: Option, + pub user_id: Option, + pub inviting_user_id: Option, + pub platform_mac: bool, + pub platform_linux: bool, + pub platform_windows: bool, + pub platform_unknown: bool, + pub editor_features: Option, + pub programming_languages: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Debug)] +pub struct Invite { + pub email_address: String, + pub email_confirmation_code: String, +} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 527f70adb8ce42bab6e88f81252812604713a7a6..468d0074d4fe28bd87883c263f503fee7f68fdd3 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -457,210 +457,210 @@ async fn test_fuzzy_search_users() { } } -// #[gpui::test] -// async fn test_invite_codes() { -// let test_db = PostgresTestDb::new(build_background_executor()); -// let db = test_db.db(); +#[gpui::test] +async fn test_invite_codes() { + let 
test_db = TestDb::postgres(build_background_executor()); + let db = test_db.db(); -// let NewUserResult { user_id: user1, .. } = db -// .create_user( -// "user1@example.com", -// false, -// NewUserParams { -// github_login: "user1".into(), -// github_user_id: 0, -// invite_count: 0, -// }, -// ) -// .await -// .unwrap(); + let NewUserResult { user_id: user1, .. } = db + .create_user( + "user1@example.com", + false, + NewUserParams { + github_login: "user1".into(), + github_user_id: 0, + invite_count: 0, + }, + ) + .await + .unwrap(); -// // Initially, user 1 has no invite code -// assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); + // Initially, user 1 has no invite code + assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); -// // Setting invite count to 0 when no code is assigned does not assign a new code -// db.set_invite_count_for_user(user1, 0).await.unwrap(); -// assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); + // Setting invite count to 0 when no code is assigned does not assign a new code + db.set_invite_count_for_user(user1, 0).await.unwrap(); + assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); -// // User 1 creates an invite code that can be used twice. -// db.set_invite_count_for_user(user1, 2).await.unwrap(); -// let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); -// assert_eq!(invite_count, 2); + // User 1 creates an invite code that can be used twice. + db.set_invite_count_for_user(user1, 2).await.unwrap(); + let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 2); -// // User 2 redeems the invite code and becomes a contact of user 1. 
-// let user2_invite = db -// .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) -// .await -// .unwrap(); -// let NewUserResult { -// user_id: user2, -// inviting_user_id, -// signup_device_id, -// metrics_id, -// } = db -// .create_user_from_invite( -// &user2_invite, -// NewUserParams { -// github_login: "user2".into(), -// github_user_id: 2, -// invite_count: 7, -// }, -// ) -// .await -// .unwrap() -// .unwrap(); -// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); -// assert_eq!(invite_count, 1); -// assert_eq!(inviting_user_id, Some(user1)); -// assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); -// assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); -// assert_eq!( -// db.get_contacts(user1).await.unwrap(), -// [Contact::Accepted { -// user_id: user2, -// should_notify: true, -// busy: false, -// }] -// ); -// assert_eq!( -// db.get_contacts(user2).await.unwrap(), -// [Contact::Accepted { -// user_id: user1, -// should_notify: false, -// busy: false, -// }] -// ); -// assert_eq!( -// db.get_invite_code_for_user(user2).await.unwrap().unwrap().1, -// 7 -// ); + // User 2 redeems the invite code and becomes a contact of user 1. 
+ let user2_invite = db + .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) + .await + .unwrap(); + let NewUserResult { + user_id: user2, + inviting_user_id, + signup_device_id, + metrics_id, + } = db + .create_user_from_invite( + &user2_invite, + NewUserParams { + github_login: "user2".into(), + github_user_id: 2, + invite_count: 7, + }, + ) + .await + .unwrap() + .unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); + assert_eq!(inviting_user_id, Some(user1)); + assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); + assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [Contact::Accepted { + user_id: user2, + should_notify: true, + busy: false, + }] + ); + assert_eq!( + db.get_contacts(user2).await.unwrap(), + [Contact::Accepted { + user_id: user1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_invite_code_for_user(user2).await.unwrap().unwrap().1, + 7 + ); -// // User 3 redeems the invite code and becomes a contact of user 1. -// let user3_invite = db -// .create_invite_from_code(&invite_code, "user3@example.com", None) -// .await -// .unwrap(); -// let NewUserResult { -// user_id: user3, -// inviting_user_id, -// signup_device_id, -// .. 
-// } = db -// .create_user_from_invite( -// &user3_invite, -// NewUserParams { -// github_login: "user-3".into(), -// github_user_id: 3, -// invite_count: 3, -// }, -// ) -// .await -// .unwrap() -// .unwrap(); -// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); -// assert_eq!(invite_count, 0); -// assert_eq!(inviting_user_id, Some(user1)); -// assert!(signup_device_id.is_none()); -// assert_eq!( -// db.get_contacts(user1).await.unwrap(), -// [ -// Contact::Accepted { -// user_id: user2, -// should_notify: true, -// busy: false, -// }, -// Contact::Accepted { -// user_id: user3, -// should_notify: true, -// busy: false, -// } -// ] -// ); -// assert_eq!( -// db.get_contacts(user3).await.unwrap(), -// [Contact::Accepted { -// user_id: user1, -// should_notify: false, -// busy: false, -// }] -// ); -// assert_eq!( -// db.get_invite_code_for_user(user3).await.unwrap().unwrap().1, -// 3 -// ); + // User 3 redeems the invite code and becomes a contact of user 1. + let user3_invite = db + .create_invite_from_code(&invite_code, "user3@example.com", None) + .await + .unwrap(); + let NewUserResult { + user_id: user3, + inviting_user_id, + signup_device_id, + .. 
+ } = db + .create_user_from_invite( + &user3_invite, + NewUserParams { + github_login: "user-3".into(), + github_user_id: 3, + invite_count: 3, + }, + ) + .await + .unwrap() + .unwrap(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 0); + assert_eq!(inviting_user_id, Some(user1)); + assert!(signup_device_id.is_none()); + assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user2, + should_notify: true, + busy: false, + }, + Contact::Accepted { + user_id: user3, + should_notify: true, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user3).await.unwrap(), + [Contact::Accepted { + user_id: user1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_invite_code_for_user(user3).await.unwrap().unwrap().1, + 3 + ); -// // Trying to reedem the code for the third time results in an error. -// db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) -// .await -// .unwrap_err(); + // Trying to reedem the code for the third time results in an error. + db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) + .await + .unwrap_err(); -// // Invite count can be updated after the code has been created. -// db.set_invite_count_for_user(user1, 2).await.unwrap(); -// let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); -// assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 -// assert_eq!(invite_count, 2); + // Invite count can be updated after the code has been created. 
+ db.set_invite_count_for_user(user1, 2).await.unwrap(); + let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 + assert_eq!(invite_count, 2); -// // User 4 can now redeem the invite code and becomes a contact of user 1. -// let user4_invite = db -// .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) -// .await -// .unwrap(); -// let user4 = db -// .create_user_from_invite( -// &user4_invite, -// NewUserParams { -// github_login: "user-4".into(), -// github_user_id: 4, -// invite_count: 5, -// }, -// ) -// .await -// .unwrap() -// .unwrap() -// .user_id; + // User 4 can now redeem the invite code and becomes a contact of user 1. + let user4_invite = db + .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) + .await + .unwrap(); + let user4 = db + .create_user_from_invite( + &user4_invite, + NewUserParams { + github_login: "user-4".into(), + github_user_id: 4, + invite_count: 5, + }, + ) + .await + .unwrap() + .unwrap() + .user_id; -// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); -// assert_eq!(invite_count, 1); -// assert_eq!( -// db.get_contacts(user1).await.unwrap(), -// [ -// Contact::Accepted { -// user_id: user2, -// should_notify: true, -// busy: false, -// }, -// Contact::Accepted { -// user_id: user3, -// should_notify: true, -// busy: false, -// }, -// Contact::Accepted { -// user_id: user4, -// should_notify: true, -// busy: false, -// } -// ] -// ); -// assert_eq!( -// db.get_contacts(user4).await.unwrap(), -// [Contact::Accepted { -// user_id: user1, -// should_notify: false, -// busy: false, -// }] -// ); -// assert_eq!( -// db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, -// 5 -// ); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); 
+ assert_eq!( + db.get_contacts(user1).await.unwrap(), + [ + Contact::Accepted { + user_id: user2, + should_notify: true, + busy: false, + }, + Contact::Accepted { + user_id: user3, + should_notify: true, + busy: false, + }, + Contact::Accepted { + user_id: user4, + should_notify: true, + busy: false, + } + ] + ); + assert_eq!( + db.get_contacts(user4).await.unwrap(), + [Contact::Accepted { + user_id: user1, + should_notify: false, + busy: false, + }] + ); + assert_eq!( + db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, + 5 + ); -// // An existing user cannot redeem invite codes. -// db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) -// .await -// .unwrap_err(); -// let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); -// assert_eq!(invite_count, 1); -// } + // An existing user cannot redeem invite codes. + db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) + .await + .unwrap_err(); + let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); + assert_eq!(invite_count, 1); +} // #[gpui::test] // async fn test_signups() { From 19d14737bfe5b6a249236586e5e81f82ac6188d8 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 11:58:07 +0100 Subject: [PATCH 081/109] Implement signups using sea-orm --- crates/collab/Cargo.toml | 2 +- crates/collab/src/db2.rs | 102 ++++++++++- crates/collab/src/db2/signup.rs | 29 +++- crates/collab/src/db2/tests.rs | 290 ++++++++++++++++---------------- 4 files changed, 271 insertions(+), 152 deletions(-) diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index a268bdd7b096b9c9ce22aea4ea30b09485b8446b..4cb91ad12deba99cbe5a5cb431018fe106c3a659 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -36,7 +36,7 @@ prometheus = "0.13" rand = "0.8" reqwest = { version = "0.11", features = ["json"], optional = true } scrypt = "0.7" -sea-orm = { version = 
"0.10", features = ["sqlx-postgres", "runtime-tokio-rustls"] } +sea-orm = { version = "0.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] } sea-query = { version = "0.27", features = ["derive"] } sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] } serde = { version = "1.0", features = ["derive", "rc"] } diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs index 75329f926894d8df21fcd52888822b68638a2f56..3aa21c60593aaf4a60189076b7f298821a64e7da 100644 --- a/crates/collab/src/db2.rs +++ b/crates/collab/src/db2.rs @@ -36,7 +36,7 @@ use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; use tokio::sync::{Mutex, OwnedMutexGuard}; pub use contact::Contact; -pub use signup::Invite; +pub use signup::{Invite, NewSignup, WaitlistSummary}; pub use user::Model as User; pub struct Database { @@ -140,6 +140,11 @@ impl Database { .await } + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) 
}) + .await + } + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { self.transact(|tx| async { let tx = tx; @@ -322,7 +327,7 @@ impl Database { } pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|mut tx| async move { + self.transact(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) } else { @@ -526,6 +531,99 @@ impl Database { .await } + // signups + + pub async fn create_signup(&self, signup: NewSignup) -> Result<()> { + self.transact(|tx| async { + signup::ActiveModel { + email_address: ActiveValue::set(signup.email_address.clone()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + platform_mac: ActiveValue::set(signup.platform_mac), + platform_windows: ActiveValue::set(signup.platform_windows), + platform_linux: ActiveValue::set(signup.platform_linux), + platform_unknown: ActiveValue::set(false), + editor_features: ActiveValue::set(Some(signup.editor_features.clone())), + programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())), + device_id: ActiveValue::set(signup.device_id.clone()), + ..Default::default() + } + .insert(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + pub async fn get_waitlist_summary(&self) -> Result { + self.transact(|tx| async move { + let query = " + SELECT + COUNT(*) as count, + COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, + COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, + COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, + COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count + FROM ( + SELECT * + FROM signups + WHERE + NOT email_confirmation_sent + ) AS unsent + "; + Ok( + WaitlistSummary::find_by_statement(Statement::from_sql_and_values( + 
self.pool.get_database_backend(), + query.into(), + vec![], + )) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("invalid result"))?, + ) + }) + .await + } + + pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + let emails = invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(); + self.transact(|tx| async { + signup::Entity::update_many() + .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) + .col_expr(signup::Column::EmailConfirmationSent, true.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + + pub async fn get_unsent_invites(&self, count: usize) -> Result> { + self.transact(|tx| async move { + Ok(signup::Entity::find() + .select_only() + .column(signup::Column::EmailAddress) + .column(signup::Column::EmailConfirmationCode) + .filter( + signup::Column::EmailConfirmationSent.eq(false).and( + signup::Column::PlatformMac + .eq(true) + .or(signup::Column::PlatformUnknown.eq(true)), + ), + ) + .limit(count as u64) + .into_model() + .all(&tx) + .await?) 
+ }) + .await + } + // invite codes pub async fn create_invite_from_code( diff --git a/crates/collab/src/db2/signup.rs b/crates/collab/src/db2/signup.rs index ad0aa5eb824b64bafc491a7b4125f333096f8210..8fab8daa3621ebe93a08ed74fc02c47a7fdfae61 100644 --- a/crates/collab/src/db2/signup.rs +++ b/crates/collab/src/db2/signup.rs @@ -1,5 +1,6 @@ use super::{SignupId, UserId}; -use sea_orm::entity::prelude::*; +use sea_orm::{entity::prelude::*, FromQueryResult}; +use serde::{Deserialize, Serialize}; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "signups")] @@ -17,8 +18,8 @@ pub struct Model { pub platform_linux: bool, pub platform_windows: bool, pub platform_unknown: bool, - pub editor_features: Option, - pub programming_languages: Option, + pub editor_features: Option>, + pub programming_languages: Option>, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] @@ -26,8 +27,28 @@ pub enum Relation {} impl ActiveModelBehavior for ActiveModel {} -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq, FromQueryResult)] pub struct Invite { pub email_address: String, pub email_confirmation_code: String, } + +#[derive(Clone, Deserialize)] +pub struct NewSignup { + pub email_address: String, + pub platform_mac: bool, + pub platform_windows: bool, + pub platform_linux: bool, + pub editor_features: Vec, + pub programming_languages: Vec, + pub device_id: Option, +} + +#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromQueryResult)] +pub struct WaitlistSummary { + pub count: i64, + pub linux_count: i64, + pub mac_count: i64, + pub windows_count: i64, + pub unknown_count: i64, +} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs index 468d0074d4fe28bd87883c263f503fee7f68fdd3..b276bd5057b7282815a4c21eeea00fd691eecff5 100644 --- a/crates/collab/src/db2/tests.rs +++ b/crates/collab/src/db2/tests.rs @@ -662,151 +662,151 @@ async fn test_invite_codes() { assert_eq!(invite_count, 1); } -// #[gpui::test] -// async fn 
test_signups() { -// let test_db = PostgresTestDb::new(build_background_executor()); -// let db = test_db.db(); - -// // people sign up on the waitlist -// for i in 0..8 { -// db.create_signup(Signup { -// email_address: format!("person-{i}@example.com"), -// platform_mac: true, -// platform_linux: i % 2 == 0, -// platform_windows: i % 4 == 0, -// editor_features: vec!["speed".into()], -// programming_languages: vec!["rust".into(), "c".into()], -// device_id: Some(format!("device_id_{i}")), -// }) -// .await -// .unwrap(); -// } - -// assert_eq!( -// db.get_waitlist_summary().await.unwrap(), -// WaitlistSummary { -// count: 8, -// mac_count: 8, -// linux_count: 4, -// windows_count: 2, -// unknown_count: 0, -// } -// ); - -// // retrieve the next batch of signup emails to send -// let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); -// let addresses = signups_batch1 -// .iter() -// .map(|s| &s.email_address) -// .collect::>(); -// assert_eq!( -// addresses, -// &[ -// "person-0@example.com", -// "person-1@example.com", -// "person-2@example.com" -// ] -// ); -// assert_ne!( -// signups_batch1[0].email_confirmation_code, -// signups_batch1[1].email_confirmation_code -// ); - -// // the waitlist isn't updated until we record that the emails -// // were successfully sent. -// let signups_batch = db.get_unsent_invites(3).await.unwrap(); -// assert_eq!(signups_batch, signups_batch1); - -// // once the emails go out, we can retrieve the next batch -// // of signups. -// db.record_sent_invites(&signups_batch1).await.unwrap(); -// let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); -// let addresses = signups_batch2 -// .iter() -// .map(|s| &s.email_address) -// .collect::>(); -// assert_eq!( -// addresses, -// &[ -// "person-3@example.com", -// "person-4@example.com", -// "person-5@example.com" -// ] -// ); - -// // the sent invites are excluded from the summary. 
-// assert_eq!( -// db.get_waitlist_summary().await.unwrap(), -// WaitlistSummary { -// count: 5, -// mac_count: 5, -// linux_count: 2, -// windows_count: 1, -// unknown_count: 0, -// } -// ); - -// // user completes the signup process by providing their -// // github account. -// let NewUserResult { -// user_id, -// inviting_user_id, -// signup_device_id, -// .. -// } = db -// .create_user_from_invite( -// &Invite { -// email_address: signups_batch1[0].email_address.clone(), -// email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), -// }, -// NewUserParams { -// github_login: "person-0".into(), -// github_user_id: 0, -// invite_count: 5, -// }, -// ) -// .await -// .unwrap() -// .unwrap(); -// let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); -// assert!(inviting_user_id.is_none()); -// assert_eq!(user.github_login, "person-0"); -// assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); -// assert_eq!(user.invite_count, 5); -// assert_eq!(signup_device_id.unwrap(), "device_id_0"); - -// // cannot redeem the same signup again. -// assert!(db -// .create_user_from_invite( -// &Invite { -// email_address: signups_batch1[0].email_address.clone(), -// email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), -// }, -// NewUserParams { -// github_login: "some-other-github_account".into(), -// github_user_id: 1, -// invite_count: 5, -// }, -// ) -// .await -// .unwrap() -// .is_none()); - -// // cannot redeem a signup with the wrong confirmation code. 
-// db.create_user_from_invite( -// &Invite { -// email_address: signups_batch1[1].email_address.clone(), -// email_confirmation_code: "the-wrong-code".to_string(), -// }, -// NewUserParams { -// github_login: "person-1".into(), -// github_user_id: 2, -// invite_count: 5, -// }, -// ) -// .await -// .unwrap_err(); -// } +#[gpui::test] +async fn test_signups() { + let test_db = TestDb::postgres(build_background_executor()); + let db = test_db.db(); + + // people sign up on the waitlist + for i in 0..8 { + db.create_signup(NewSignup { + email_address: format!("person-{i}@example.com"), + platform_mac: true, + platform_linux: i % 2 == 0, + platform_windows: i % 4 == 0, + editor_features: vec!["speed".into()], + programming_languages: vec!["rust".into(), "c".into()], + device_id: Some(format!("device_id_{i}")), + }) + .await + .unwrap(); + } + + assert_eq!( + db.get_waitlist_summary().await.unwrap(), + WaitlistSummary { + count: 8, + mac_count: 8, + linux_count: 4, + windows_count: 2, + unknown_count: 0, + } + ); + + // retrieve the next batch of signup emails to send + let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); + let addresses = signups_batch1 + .iter() + .map(|s| &s.email_address) + .collect::>(); + assert_eq!( + addresses, + &[ + "person-0@example.com", + "person-1@example.com", + "person-2@example.com" + ] + ); + assert_ne!( + signups_batch1[0].email_confirmation_code, + signups_batch1[1].email_confirmation_code + ); + + // the waitlist isn't updated until we record that the emails + // were successfully sent. + let signups_batch = db.get_unsent_invites(3).await.unwrap(); + assert_eq!(signups_batch, signups_batch1); + + // once the emails go out, we can retrieve the next batch + // of signups. 
+ db.record_sent_invites(&signups_batch1).await.unwrap(); + let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); + let addresses = signups_batch2 + .iter() + .map(|s| &s.email_address) + .collect::>(); + assert_eq!( + addresses, + &[ + "person-3@example.com", + "person-4@example.com", + "person-5@example.com" + ] + ); + + // the sent invites are excluded from the summary. + assert_eq!( + db.get_waitlist_summary().await.unwrap(), + WaitlistSummary { + count: 5, + mac_count: 5, + linux_count: 2, + windows_count: 1, + unknown_count: 0, + } + ); + + // user completes the signup process by providing their + // github account. + let NewUserResult { + user_id, + inviting_user_id, + signup_device_id, + .. + } = db + .create_user_from_invite( + &Invite { + email_address: signups_batch1[0].email_address.clone(), + email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), + }, + NewUserParams { + github_login: "person-0".into(), + github_user_id: 0, + invite_count: 5, + }, + ) + .await + .unwrap() + .unwrap(); + let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); + assert!(inviting_user_id.is_none()); + assert_eq!(user.github_login, "person-0"); + assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); + assert_eq!(user.invite_count, 5); + assert_eq!(signup_device_id.unwrap(), "device_id_0"); + + // cannot redeem the same signup again. + assert!(db + .create_user_from_invite( + &Invite { + email_address: signups_batch1[0].email_address.clone(), + email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), + }, + NewUserParams { + github_login: "some-other-github_account".into(), + github_user_id: 1, + invite_count: 5, + }, + ) + .await + .unwrap() + .is_none()); + + // cannot redeem a signup with the wrong confirmation code. 
+ db.create_user_from_invite( + &Invite { + email_address: signups_batch1[1].email_address.clone(), + email_confirmation_code: "the-wrong-code".to_string(), + }, + NewUserParams { + github_login: "person-1".into(), + github_user_id: 2, + invite_count: 5, + }, + ) + .await + .unwrap_err(); +} fn build_background_executor() -> Arc { Deterministic::new(0).build_background() From d2385bd6a0d90771cec772267916b3a7f566ea35 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 14:40:37 +0100 Subject: [PATCH 082/109] Start using the new sea-orm backed database --- Cargo.lock | 1 - crates/collab/Cargo.toml | 4 +- crates/collab/src/api.rs | 4 +- crates/collab/src/auth.rs | 2 +- crates/collab/src/db.rs | 3628 +++++------------ crates/collab/src/{db2 => db}/access_token.rs | 0 crates/collab/src/{db2 => db}/contact.rs | 0 crates/collab/src/{db2 => db}/project.rs | 0 .../src/{db2 => db}/project_collaborator.rs | 0 crates/collab/src/{db2 => db}/room.rs | 0 .../src/{db2 => db}/room_participant.rs | 0 crates/collab/src/db/schema.rs | 43 - crates/collab/src/{db2 => db}/signup.rs | 2 +- crates/collab/src/db/tests.rs | 35 +- crates/collab/src/{db2 => db}/user.rs | 4 +- crates/collab/src/{db2 => db}/worktree.rs | 0 crates/collab/src/db2.rs | 1416 ------- crates/collab/src/db2/tests.rs | 813 ---- crates/collab/src/integration_tests.rs | 4 +- crates/collab/src/main.rs | 13 +- crates/collab/src/rpc.rs | 6 +- 21 files changed, 1102 insertions(+), 4873 deletions(-) rename crates/collab/src/{db2 => db}/access_token.rs (100%) rename crates/collab/src/{db2 => db}/contact.rs (100%) rename crates/collab/src/{db2 => db}/project.rs (100%) rename crates/collab/src/{db2 => db}/project_collaborator.rs (100%) rename crates/collab/src/{db2 => db}/room.rs (100%) rename crates/collab/src/{db2 => db}/room_participant.rs (100%) delete mode 100644 crates/collab/src/db/schema.rs rename crates/collab/src/{db2 => db}/signup.rs (95%) rename crates/collab/src/{db2 => db}/user.rs (93%) rename 
crates/collab/src/{db2 => db}/worktree.rs (100%) delete mode 100644 crates/collab/src/db2.rs delete mode 100644 crates/collab/src/db2/tests.rs diff --git a/Cargo.lock b/Cargo.lock index 7b09775f2a46bad44cfcd2d98645bad8640828e1..590835a49bcfeb5b05a99ee5e00d7d1efff8d0e4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1159,7 +1159,6 @@ dependencies = [ "scrypt", "sea-orm", "sea-query", - "sea-query-binder", "serde", "serde_json", "settings", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 4cb91ad12deba99cbe5a5cb431018fe106c3a659..66f426839cc4e2d6139e4c53004b38c3bf8d13f2 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -37,8 +37,7 @@ rand = "0.8" reqwest = { version = "0.11", features = ["json"], optional = true } scrypt = "0.7" sea-orm = { version = "0.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] } -sea-query = { version = "0.27", features = ["derive"] } -sea-query-binder = { version = "0.2", features = ["sqlx-postgres"] } +sea-query = "0.27" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" sha-1 = "0.9" @@ -76,7 +75,6 @@ log = { version = "0.4.16", features = ["kv_unstable_serde"] } util = { path = "../util" } lazy_static = "1.4" sea-orm = { version = "0.10", features = ["sqlx-sqlite"] } -sea-query-binder = { version = "0.2", features = ["sqlx-sqlite"] } serde_json = { version = "1.0", features = ["preserve_order"] } sqlx = { version = "0.6", features = ["sqlite"] } unindent = "0.1" diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index 5fcdc5fcfdf59a983d3d4c04d98242eb3d97fa41..bf183edf5440460cbd9f1d6043277266d346c8b5 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -1,6 +1,6 @@ use crate::{ auth, - db::{Invite, NewUserParams, Signup, User, UserId, WaitlistSummary}, + db::{Invite, NewSignup, NewUserParams, User, UserId, WaitlistSummary}, rpc::{self, ResultExt}, AppState, Error, Result, }; @@ -335,7 +335,7 @@ async fn 
get_user_for_invite_code( } async fn create_signup( - Json(params): Json, + Json(params): Json, Extension(app): Extension>, ) -> Result<()> { app.db.create_signup(params).await?; diff --git a/crates/collab/src/auth.rs b/crates/collab/src/auth.rs index 63f032f7e65d17f454d603b26c6206c81eacdf65..0c9cf33a6b94b369a9ea47e92e254ffd87e151ab 100644 --- a/crates/collab/src/auth.rs +++ b/crates/collab/src/auth.rs @@ -75,7 +75,7 @@ pub async fn validate_header(mut req: Request, next: Next) -> impl Into const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &db::DefaultDb, user_id: UserId) -> Result { +pub async fn create_access_token(db: &db::Database, user_id: UserId) -> Result { let access_token = rpc::auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 044d4ef8d7790f48491e0d4797080f78073662ce..d89d041f2a832d17201e0c4f23d1d76aed32a5ef 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,42 +1,44 @@ -mod schema; +mod access_token; +mod contact; +mod project; +mod project_collaborator; +mod room; +mod room_participant; +mod signup; #[cfg(test)] mod tests; +mod user; +mod worktree; use crate::{Error, Result}; use anyhow::anyhow; -use axum::http::StatusCode; -use collections::{BTreeMap, HashMap, HashSet}; +use collections::HashMap; +pub use contact::Contact; use dashmap::DashMap; -use futures::{future::BoxFuture, FutureExt, StreamExt}; +use futures::StreamExt; +use hyper::StatusCode; use rpc::{proto, ConnectionId}; -use sea_query::{Expr, Query}; -use sea_query_binder::SqlxBinder; -use serde::{Deserialize, Serialize}; -use sqlx::{ - migrate::{Migrate as _, Migration, MigrationSource}, - types::Uuid, - FromRow, -}; -use std::{ - future::Future, - marker::PhantomData, - ops::{Deref, DerefMut}, - path::Path, - rc::Rc, - sync::Arc, - time::Duration, +pub use sea_orm::ConnectOptions; +use sea_orm::{ + 
entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseBackend, DatabaseConnection, + DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, JoinType, QueryOrder, + QuerySelect, Statement, TransactionTrait, }; -use time::{OffsetDateTime, PrimitiveDateTime}; +use sea_query::{Alias, Expr, OnConflict, Query}; +use serde::{Deserialize, Serialize}; +pub use signup::{Invite, NewSignup, WaitlistSummary}; +use sqlx::migrate::{Migrate, Migration, MigrationSource}; +use sqlx::Connection; +use std::ops::{Deref, DerefMut}; +use std::path::Path; +use std::time::Duration; +use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; use tokio::sync::{Mutex, OwnedMutexGuard}; +pub use user::Model as User; -#[cfg(test)] -pub type DefaultDb = Db; - -#[cfg(not(test))] -pub type DefaultDb = Db; - -pub struct Db { - pool: sqlx::Pool, +pub struct Database { + options: ConnectOptions, + pool: DatabaseConnection, rooms: DashMap>>, #[cfg(test)] background: Option>, @@ -44,142 +46,61 @@ pub struct Db { runtime: Option, } -pub struct RoomGuard { - data: T, - _guard: OwnedMutexGuard<()>, - _not_send: PhantomData>, -} - -impl Deref for RoomGuard { - type Target = T; - - fn deref(&self) -> &T { - &self.data - } -} - -impl DerefMut for RoomGuard { - fn deref_mut(&mut self) -> &mut T { - &mut self.data +impl Database { + pub async fn new(options: ConnectOptions) -> Result { + Ok(Self { + options: options.clone(), + pool: sea_orm::Database::connect(options).await?, + rooms: DashMap::with_capacity(16384), + #[cfg(test)] + background: None, + #[cfg(test)] + runtime: None, + }) } -} -pub trait BeginTransaction: Send + Sync { - type Database: sqlx::Database; + pub async fn migrate( + &self, + migrations_path: &Path, + ignore_checksum_mismatch: bool, + ) -> anyhow::Result> { + let migrations = MigrationSource::resolve(migrations_path) + .await + .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - fn begin_transaction(&self) -> BoxFuture>>; -} + let mut connection = 
sqlx::AnyConnection::connect(self.options.get_url()).await?; -// In Postgres, serializable transactions are opt-in -impl BeginTransaction for Db { - type Database = sqlx::Postgres; + connection.ensure_migrations_table().await?; + let applied_migrations: HashMap<_, _> = connection + .list_applied_migrations() + .await? + .into_iter() + .map(|m| (m.version, m)) + .collect(); - fn begin_transaction(&self) -> BoxFuture>> { - async move { - let mut tx = self.pool.begin().await?; - sqlx::Executor::execute(&mut tx, "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;") - .await?; - Ok(tx) + let mut new_migrations = Vec::new(); + for migration in migrations { + match applied_migrations.get(&migration.version) { + Some(applied_migration) => { + if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch + { + Err(anyhow!( + "checksum mismatch for applied migration {}", + migration.description + ))?; + } + } + None => { + let elapsed = connection.apply(&migration).await?; + new_migrations.push((migration, elapsed)); + } + } } - .boxed() - } -} -// In Sqlite, transactions are inherently serializable. -#[cfg(test)] -impl BeginTransaction for Db { - type Database = sqlx::Sqlite; - - fn begin_transaction(&self) -> BoxFuture>> { - async move { Ok(self.pool.begin().await?) 
}.boxed() - } -} - -pub trait BuildQuery { - fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues); -} - -impl BuildQuery for Db { - fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues) { - query.build_sqlx(sea_query::PostgresQueryBuilder) - } -} - -#[cfg(test)] -impl BuildQuery for Db { - fn build_query(&self, query: &T) -> (String, sea_query_binder::SqlxValues) { - query.build_sqlx(sea_query::SqliteQueryBuilder) - } -} - -pub trait RowsAffected { - fn rows_affected(&self) -> u64; -} - -#[cfg(test)] -impl RowsAffected for sqlx::sqlite::SqliteQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -impl RowsAffected for sqlx::postgres::PgQueryResult { - fn rows_affected(&self) -> u64 { - self.rows_affected() - } -} - -#[cfg(test)] -impl Db { - pub async fn new(url: &str, max_connections: u32) -> Result { - use std::str::FromStr as _; - let options = sqlx::sqlite::SqliteConnectOptions::from_str(url) - .unwrap() - .create_if_missing(true) - .shared_cache(true); - let pool = sqlx::sqlite::SqlitePoolOptions::new() - .min_connections(2) - .max_connections(max_connections) - .connect_with(options) - .await?; - Ok(Self { - pool, - rooms: Default::default(), - background: None, - runtime: None, - }) - } - - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let query = " - SELECT users.* - FROM users - WHERE users.id IN (SELECT value from json_each($1)) - "; - Ok(sqlx::query_as(query) - .bind(&serde_json::json!(ids)) - .fetch_all(&mut tx) - .await?) - }) - .await + Ok(new_migrations) } - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - self.transact(|mut tx| async move { - let query = " - SELECT metrics_id - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&mut tx) - .await?) 
- }) - .await - } + // users pub async fn create_user( &self, @@ -187,26 +108,28 @@ impl Db { admin: bool, params: NewUserParams, ) -> Result { - self.transact(|mut tx| async { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin, metrics_id) - VALUES ($1, $2, $3, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id - "; + self.transact(|tx| async { + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(email_address.into())), + github_login: ActiveValue::set(params.github_login.clone()), + github_user_id: ActiveValue::set(Some(params.github_user_id)), + admin: ActiveValue::set(admin), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_column(user::Column::GithubLogin) + .to_owned(), + ) + .exec_with_returning(&tx) + .await?; - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(¶ms.github_login) - .bind(¶ms.github_user_id) - .bind(admin) - .bind(Uuid::new_v4().to_string()) - .fetch_one(&mut tx) - .await?; tx.commit().await?; + Ok(NewUserResult { - user_id, - metrics_id, + user_id: user.id, + metrics_id: user.metrics_id.to_string(), signup_device_id: None, inviting_user_id: None, }) @@ -214,481 +137,418 @@ impl Db { .await } - pub async fn fuzzy_search_users(&self, _name_query: &str, _limit: u32) -> Result> { - unimplemented!() - } - - pub async fn create_user_from_invite( - &self, - _invite: &Invite, - _user: NewUserParams, - ) -> Result> { - unimplemented!() - } - - pub async fn create_signup(&self, _signup: Signup) -> Result<()> { - unimplemented!() - } - - pub async fn create_invite_from_code( - &self, - _code: &str, - _email_address: &str, - _device_id: Option<&str>, - ) -> Result { - unimplemented!() - } - - pub async fn record_sent_invites(&self, _invites: &[Invite]) -> Result<()> { 
- unimplemented!() - } -} - -impl Db { - pub async fn new(url: &str, max_connections: u32) -> Result { - let pool = sqlx::postgres::PgPoolOptions::new() - .max_connections(max_connections) - .connect(url) - .await?; - Ok(Self { - pool, - rooms: DashMap::with_capacity(16384), - #[cfg(test)] - background: None, - #[cfg(test)] - runtime: None, - }) - } - - #[cfg(test)] - pub fn teardown(&self, url: &str) { - self.runtime.as_ref().unwrap().block_on(async { - use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = current_database() AND pid <> pg_backend_pid(); - "; - sqlx::query(query).execute(&self.pool).await.log_err(); - self.pool.close().await; - ::drop_database(url) - .await - .log_err(); - }) + pub async fn get_user_by_id(&self, id: UserId) -> Result> { + self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) }) + .await } - pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { + pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { self.transact(|tx| async { - let mut tx = tx; - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - Ok(sqlx::query_as(query) - .bind(like_string) - .bind(name_query) - .bind(limit as i32) - .fetch_all(&mut tx) + let tx = tx; + Ok(user::Entity::find() + .filter(user::Column::Id.is_in(ids.iter().copied())) + .all(&tx) .await?) }) .await } - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - let ids = ids.iter().map(|id| id.0).collect::>(); + pub async fn get_user_by_github_account( + &self, + github_login: &str, + github_user_id: Option, + ) -> Result> { self.transact(|tx| async { - let mut tx = tx; - let query = " - SELECT users.* - FROM users - WHERE users.id = ANY ($1) - "; - Ok(sqlx::query_as(query).bind(&ids).fetch_all(&mut tx).await?) 
+ let tx = tx; + if let Some(github_user_id) = github_user_id { + if let Some(user_by_github_user_id) = user::Entity::find() + .filter(user::Column::GithubUserId.eq(github_user_id)) + .one(&tx) + .await? + { + let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); + user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); + Ok(Some(user_by_github_user_id.update(&tx).await?)) + } else if let Some(user_by_github_login) = user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await? + { + let mut user_by_github_login = user_by_github_login.into_active_model(); + user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); + Ok(Some(user_by_github_login.update(&tx).await?)) + } else { + Ok(None) + } + } else { + Ok(user::Entity::find() + .filter(user::Column::GithubLogin.eq(github_login)) + .one(&tx) + .await?) + } }) .await } - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - self.transact(|mut tx| async move { - let query = " - SELECT metrics_id::text - FROM users - WHERE id = $1 - "; - Ok(sqlx::query_scalar(query) - .bind(id) - .fetch_one(&mut tx) + pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { + self.transact(|tx| async move { + Ok(user::Entity::find() + .order_by_asc(user::Column::GithubLogin) + .limit(limit as u64) + .offset(page as u64 * limit as u64) + .all(&tx) .await?) 
}) .await } - pub async fn create_user( + pub async fn get_users_with_no_invites( &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.transact(|mut tx| async { - let query = " - INSERT INTO users (email_address, github_login, github_user_id, admin) - VALUES ($1, $2, $3, $4) - ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login - RETURNING id, metrics_id::text - "; - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as(query) - .bind(email_address) - .bind(¶ms.github_login) - .bind(params.github_user_id) - .bind(admin) - .fetch_one(&mut tx) - .await?; - tx.commit().await?; - - Ok(NewUserResult { - user_id, - metrics_id, - signup_device_id: None, - inviting_user_id: None, - }) + invited_by_another_user: bool, + ) -> Result> { + self.transact(|tx| async move { + Ok(user::Entity::find() + .filter( + user::Column::InviteCount + .eq(0) + .and(if invited_by_another_user { + user::Column::InviterId.is_not_null() + } else { + user::Column::InviterId.is_null() + }), + ) + .all(&tx) + .await?) }) .await } - pub async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result> { - self.transact(|mut tx| async { - let (signup_id, existing_user_id, inviting_user_id, signup_device_id): ( - i32, - Option, - Option, - Option, - ) = sqlx::query_as( - " - SELECT id, user_id, inviting_user_id, device_id - FROM signups - WHERE - email_address = $1 AND - email_confirmation_code = $2 - ", - ) - .bind(&invite.email_address) - .bind(&invite.email_confirmation_code) - .fetch_optional(&mut tx) - .await? 
- .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if existing_user_id.is_some() { - return Ok(None); - } - - let (user_id, metrics_id): (UserId, String) = sqlx::query_as( - " - INSERT INTO users - (email_address, github_login, github_user_id, admin, invite_count, invite_code) - VALUES - ($1, $2, $3, FALSE, $4, $5) - ON CONFLICT (github_login) DO UPDATE SET - email_address = excluded.email_address, - github_user_id = excluded.github_user_id, - admin = excluded.admin - RETURNING id, metrics_id::text - ", - ) - .bind(&invite.email_address) - .bind(&user.github_login) - .bind(&user.github_user_id) - .bind(&user.invite_count) - .bind(random_invite_code()) - .fetch_one(&mut tx) - .await?; - - sqlx::query( - " - UPDATE signups - SET user_id = $1 - WHERE id = $2 - ", - ) - .bind(&user_id) - .bind(&signup_id) - .execute(&mut tx) - .await?; - - if let Some(inviting_user_id) = inviting_user_id { - let id: Option = sqlx::query_scalar( - " - UPDATE users - SET invite_count = invite_count - 1 - WHERE id = $1 AND invite_count > 0 - RETURNING id - ", - ) - .bind(&inviting_user_id) - .fetch_optional(&mut tx) - .await?; + pub async fn get_user_metrics_id(&self, id: UserId) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + MetricsId, + } - if id.is_none() { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } + self.transact(|tx| async move { + let metrics_id: Uuid = user::Entity::find_by_id(id) + .select_only() + .column(user::Column::MetricsId) + .into_values::<_, QueryAs>() + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find user"))?; + Ok(metrics_id.to_string()) + }) + .await + } - sqlx::query( - " - INSERT INTO contacts - (user_id_a, user_id_b, a_to_b, should_notify, accepted) - VALUES - ($1, $2, TRUE, TRUE, TRUE) - ON CONFLICT DO NOTHING - ", - ) - .bind(inviting_user_id) - .bind(user_id) - .execute(&mut tx) + pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + self.transact(|tx| async move { + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .col_expr(user::Column::Admin, is_admin.into()) + .exec(&tx) .await?; - } - tx.commit().await?; - Ok(Some(NewUserResult { - user_id, - metrics_id, - inviting_user_id, - signup_device_id, - })) + Ok(()) }) .await } - pub async fn create_signup(&self, signup: Signup) -> Result<()> { - self.transact(|mut tx| async { - sqlx::query( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - editor_features, - programming_languages, - device_id - ) - VALUES - ($1, $2, FALSE, $3, $4, $5, FALSE, $6, $7, $8) - RETURNING id - ", - ) - .bind(&signup.email_address) - .bind(&random_email_confirmation_code()) - .bind(&signup.platform_linux) - .bind(&signup.platform_mac) - .bind(&signup.platform_windows) - .bind(&signup.editor_features) - .bind(&signup.programming_languages) - .bind(&signup.device_id) - .execute(&mut tx) - .await?; + pub async fn destroy_user(&self, id: UserId) -> Result<()> { + self.transact(|tx| async move { + access_token::Entity::delete_many() + .filter(access_token::Column::UserId.eq(id)) + .exec(&tx) + .await?; + user::Entity::delete_by_id(id).exec(&tx).await?; tx.commit().await?; Ok(()) }) .await } - pub async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result { - self.transact(|mut tx| async { - let existing_user: Option = sqlx::query_scalar( - " - SELECT id - FROM 
users - WHERE email_address = $1 - ", - ) - .bind(email_address) - .fetch_optional(&mut tx) - .await?; - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } + // contacts - let row: Option<(UserId, i32)> = sqlx::query_as( - " - SELECT id, invite_count - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await?; + pub async fn get_contacts(&self, user_id: UserId) -> Result> { + #[derive(Debug, FromQueryResult)] + struct ContactWithUserBusyStatuses { + user_id_a: UserId, + user_id_b: UserId, + a_to_b: bool, + accepted: bool, + should_notify: bool, + user_a_busy: bool, + user_b_busy: bool, + } - let (inviter_id, invite_count) = match row { - Some(row) => row, - None => Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))?, - }; + self.transact(|tx| async move { + let user_a_participant = Alias::new("user_a_participant"); + let user_b_participant = Alias::new("user_b_participant"); + let mut db_contacts = contact::Entity::find() + .column_as( + Expr::tbl(user_a_participant.clone(), room_participant::Column::Id) + .is_not_null(), + "user_a_busy", + ) + .column_as( + Expr::tbl(user_b_participant.clone(), room_participant::Column::Id) + .is_not_null(), + "user_b_busy", + ) + .filter( + contact::Column::UserIdA + .eq(user_id) + .or(contact::Column::UserIdB.eq(user_id)), + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserARoomParticipant.def(), + user_a_participant, + ) + .join_as( + JoinType::LeftJoin, + contact::Relation::UserBRoomParticipant.def(), + user_b_participant, + ) + .into_model::() + .stream(&tx) + .await?; - if invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; + let mut contacts = Vec::new(); + while let Some(db_contact) = db_contacts.next().await { + let db_contact = db_contact?; + if db_contact.user_id_a == user_id { + if db_contact.accepted { + contacts.push(Contact::Accepted { + 
user_id: db_contact.user_id_b, + should_notify: db_contact.should_notify && db_contact.a_to_b, + busy: db_contact.user_b_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_b, + }) + } else { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_b, + should_notify: db_contact.should_notify, + }); + } + } else if db_contact.accepted { + contacts.push(Contact::Accepted { + user_id: db_contact.user_id_a, + should_notify: db_contact.should_notify && !db_contact.a_to_b, + busy: db_contact.user_a_busy, + }); + } else if db_contact.a_to_b { + contacts.push(Contact::Incoming { + user_id: db_contact.user_id_a, + should_notify: db_contact.should_notify, + }); + } else { + contacts.push(Contact::Outgoing { + user_id: db_contact.user_id_a, + }); + } } - let email_confirmation_code: String = sqlx::query_scalar( - " - INSERT INTO signups - ( - email_address, - email_confirmation_code, - email_confirmation_sent, - inviting_user_id, - platform_linux, - platform_mac, - platform_windows, - platform_unknown, - device_id - ) - VALUES - ($1, $2, FALSE, $3, FALSE, FALSE, FALSE, TRUE, $4) - ON CONFLICT (email_address) - DO UPDATE SET - inviting_user_id = excluded.inviting_user_id - RETURNING email_confirmation_code - ", - ) - .bind(&email_address) - .bind(&random_email_confirmation_code()) - .bind(&inviter_id) - .bind(&device_id) - .fetch_one(&mut tx) - .await?; + contacts.sort_unstable_by_key(|contact| contact.user_id()); - tx.commit().await?; + Ok(contacts) + }) + .await + } - Ok(Invite { - email_address: email_address.into(), - email_confirmation_code, - }) + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { + self.transact(|tx| async move { + let (id_a, id_b) = if user_id_1 < user_id_2 { + (user_id_1, user_id_2) + } else { + (user_id_2, user_id_1) + }; + + Ok(contact::Entity::find() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + 
.and(contact::Column::Accepted.eq(true)), + ) + .one(&tx) + .await? + .is_some()) }) .await } - pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - self.transact(|mut tx| async { - let emails = invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(); - sqlx::query( - " - UPDATE signups - SET email_confirmation_sent = TRUE - WHERE email_address = ANY ($1) - ", + pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; + + let rows_affected = contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(id_a), + user_id_b: ActiveValue::set(id_b), + a_to_b: ActiveValue::set(a_to_b), + accepted: ActiveValue::set(false), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB]) + .values([ + (contact::Column::Accepted, true.into()), + (contact::Column::ShouldNotify, false.into()), + ]) + .action_and_where( + contact::Column::Accepted.eq(false).and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::UserIdA.eq(id_b)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::UserIdA.eq(id_a))), + ), + ) + .to_owned(), ) - .bind(&emails) - .execute(&mut tx) + .exec_without_returning(&tx) .await?; - tx.commit().await?; - Ok(()) + + if rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("contact already requested"))? 
+ } }) .await } -} -impl Db -where - Self: BeginTransaction + BuildQuery, - D: sqlx::Database + sqlx::migrate::MigrateDatabase, - D::Connection: sqlx::migrate::Migrate, - for<'a> >::Arguments: sqlx::IntoArguments<'a, D>, - for<'a> sea_query_binder::SqlxValues: sqlx::IntoArguments<'a, D>, - for<'a> &'a mut D::Connection: sqlx::Executor<'a, Database = D>, - for<'a, 'b> &'b mut sqlx::Transaction<'a, D>: sqlx::Executor<'b, Database = D>, - D::QueryResult: RowsAffected, - String: sqlx::Type, - i32: sqlx::Type, - i64: sqlx::Type, - bool: sqlx::Type, - str: sqlx::Type, - Uuid: sqlx::Type, - sqlx::types::Json: sqlx::Type, - OffsetDateTime: sqlx::Type, - PrimitiveDateTime: sqlx::Type, - usize: sqlx::ColumnIndex, - for<'a> &'a str: sqlx::ColumnIndex, - for<'a> &'a str: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> String: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option<&'a str>: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> i32: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> i64: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> bool: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Uuid: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> Option: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> sqlx::types::JsonValue: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> OffsetDateTime: sqlx::Encode<'a, D> + sqlx::Decode<'a, D>, - for<'a> PrimitiveDateTime: sqlx::Decode<'a, D> + sqlx::Decode<'a, D>, -{ - pub async fn migrate( - &self, - migrations_path: &Path, - ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { - let migrations = MigrationSource::resolve(migrations_path) - .await - .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; + pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b) = if responder_id < requester_id { + (responder_id, requester_id) + } else 
{ + (requester_id, responder_id) + }; - let mut conn = self.pool.acquire().await?; + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)), + ) + .exec(&tx) + .await?; - conn.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = conn - .list_applied_migrations() - .await? - .into_iter() - .map(|m| (m.version, m)) - .collect(); + if result.rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("no such contact"))? + } + }) + .await + } - let mut new_migrations = Vec::new(); - for migration in migrations { - match applied_migrations.get(&migration.version) { - Some(applied_migration) => { - if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch - { - Err(anyhow!( - "checksum mismatch for applied migration {}", - migration.description - ))?; - } - } - None => { - let elapsed = conn.apply(&migration).await?; - new_migrations.push((migration, elapsed)); - } + pub async fn dismiss_contact_notification( + &self, + user_id: UserId, + contact_user_id: UserId, + ) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if user_id < contact_user_id { + (user_id, contact_user_id, true) + } else { + (contact_user_id, user_id, false) + }; + + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + should_notify: ActiveValue::set(false), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and( + contact::Column::AToB + .eq(a_to_b) + .and(contact::Column::Accepted.eq(true)) + .or(contact::Column::AToB + .ne(a_to_b) + .and(contact::Column::Accepted.eq(false))), + ), + ) + .exec(&tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("no such contact request"))? 
+ } else { + tx.commit().await?; + Ok(()) } - } + }) + .await + } - Ok(new_migrations) + pub async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()> { + self.transact(|tx| async move { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let rows_affected = if accept { + let result = contact::Entity::update_many() + .set(contact::ActiveModel { + accepted: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + ..Default::default() + }) + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)), + ) + .exec(&tx) + .await?; + result.rows_affected + } else { + let result = contact::Entity::delete_many() + .filter( + contact::Column::UserIdA + .eq(id_a) + .and(contact::Column::UserIdB.eq(id_b)) + .and(contact::Column::AToB.eq(a_to_b)) + .and(contact::Column::Accepted.eq(false)), + ) + .exec(&tx) + .await?; + + result.rows_affected + }; + + if rows_affected == 1 { + tx.commit().await?; + Ok(()) + } else { + Err(anyhow!("no such contact request"))? + } + }) + .await } pub fn fuzzy_like_string(string: &str) -> String { @@ -703,163 +563,58 @@ where result } - // users - - pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let query = "SELECT * FROM users ORDER BY github_login ASC LIMIT $1 OFFSET $2"; - Ok(sqlx::query_as(query) - .bind(limit as i32) - .bind((page * limit) as i32) - .fetch_all(&mut tx) - .await?) 
- }) - .await - } - - pub async fn get_user_by_id(&self, id: UserId) -> Result> { + pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { self.transact(|tx| async { - let mut tx = tx; + let tx = tx; + let like_string = Self::fuzzy_like_string(name_query); let query = " SELECT users.* FROM users - WHERE id = $1 - LIMIT 1 + WHERE github_login ILIKE $1 + ORDER BY github_login <-> $2 + LIMIT $3 "; - Ok(sqlx::query_as(query) - .bind(&id) - .fetch_optional(&mut tx) + + Ok(user::Entity::find() + .from_raw_sql(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query.into(), + vec![like_string.into(), name_query.into(), limit.into()], + )) + .all(&tx) .await?) }) .await } - pub async fn get_users_with_no_invites( - &self, - invited_by_another_user: bool, - ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let query = format!( - " - SELECT users.* - FROM users - WHERE invite_count = 0 - AND inviter_id IS{} NULL - ", - if invited_by_another_user { " NOT" } else { "" } - ); - - Ok(sqlx::query_as(&query).fetch_all(&mut tx).await?) 
- }) - .await - } + // signups - pub async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { + pub async fn create_signup(&self, signup: NewSignup) -> Result<()> { self.transact(|tx| async { - let mut tx = tx; - if let Some(github_user_id) = github_user_id { - let mut user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_login = $1 - WHERE github_user_id = $2 - RETURNING * - ", - ) - .bind(github_login) - .bind(github_user_id) - .fetch_optional(&mut tx) - .await?; - - if user.is_none() { - user = sqlx::query_as::<_, User>( - " - UPDATE users - SET github_user_id = $1 - WHERE github_login = $2 - RETURNING * - ", - ) - .bind(github_user_id) - .bind(github_login) - .fetch_optional(&mut tx) - .await?; - } - - Ok(user) - } else { - let user = sqlx::query_as( - " - SELECT * FROM users - WHERE github_login = $1 - LIMIT 1 - ", - ) - .bind(github_login) - .fetch_optional(&mut tx) - .await?; - Ok(user) + signup::ActiveModel { + email_address: ActiveValue::set(signup.email_address.clone()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + platform_mac: ActiveValue::set(signup.platform_mac), + platform_windows: ActiveValue::set(signup.platform_windows), + platform_linux: ActiveValue::set(signup.platform_linux), + platform_unknown: ActiveValue::set(false), + editor_features: ActiveValue::set(Some(signup.editor_features.clone())), + programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())), + device_id: ActiveValue::set(signup.device_id.clone()), + ..Default::default() } - }) - .await - } - - pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - self.transact(|mut tx| async { - let query = "UPDATE users SET admin = $1 WHERE id = $2"; - sqlx::query(query) - .bind(is_admin) - .bind(id.0) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub 
async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - self.transact(|mut tx| async move { - let query = "UPDATE users SET connected_once = $1 WHERE id = $2"; - sqlx::query(query) - .bind(connected_once) - .bind(id.0) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn destroy_user(&self, id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - let query = "DELETE FROM access_tokens WHERE user_id = $1;"; - sqlx::query(query) - .bind(id.0) - .execute(&mut tx) - .await - .map(drop)?; - let query = "DELETE FROM users WHERE id = $1;"; - sqlx::query(query).bind(id.0).execute(&mut tx).await?; + .insert(&tx) + .await?; tx.commit().await?; Ok(()) }) .await } - // signups - pub async fn get_waitlist_summary(&self) -> Result { - self.transact(|mut tx| async move { - Ok(sqlx::query_as( - " + self.transact(|tx| async move { + let query = " SELECT COUNT(*) as count, COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, @@ -872,1671 +627,435 @@ where WHERE NOT email_confirmation_sent ) AS unsent - ", + "; + Ok( + WaitlistSummary::find_by_statement(Statement::from_sql_and_values( + self.pool.get_database_backend(), + query.into(), + vec![], + )) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("invalid result"))?, ) - .fetch_one(&mut tx) - .await?) }) .await } - pub async fn get_unsent_invites(&self, count: usize) -> Result> { - self.transact(|mut tx| async move { - Ok(sqlx::query_as( - " - SELECT - email_address, email_confirmation_code - FROM signups - WHERE - NOT email_confirmation_sent AND - (platform_mac OR platform_unknown) - LIMIT $1 - ", - ) - .bind(count as i32) - .fetch_all(&mut tx) - .await?) 
+ pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { + let emails = invites + .iter() + .map(|s| s.email_address.as_str()) + .collect::>(); + self.transact(|tx| async { + signup::Entity::update_many() + .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) + .col_expr(signup::Column::EmailConfirmationSent, true.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) }) .await } - // invite codes - - pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - self.transact(|mut tx| async move { - if count > 0 { - sqlx::query( - " - UPDATE users - SET invite_code = $1 - WHERE id = $2 AND invite_code IS NULL - ", - ) - .bind(random_invite_code()) - .bind(id) - .execute(&mut tx) - .await?; - } - - sqlx::query( - " - UPDATE users - SET invite_count = $1 - WHERE id = $2 - ", - ) - .bind(count as i32) - .bind(id) - .execute(&mut tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - self.transact(|mut tx| async move { - let result: Option<(String, i32)> = sqlx::query_as( - " - SELECT invite_code, invite_count - FROM users - WHERE id = $1 AND invite_code IS NOT NULL - ", - ) - .bind(id) - .fetch_optional(&mut tx) - .await?; - if let Some((code, count)) = result { - Ok(Some((code, count.try_into().map_err(anyhow::Error::new)?))) - } else { - Ok(None) - } - }) - .await - } - - pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - self.transact(|tx| async { - let mut tx = tx; - sqlx::query_as( - " - SELECT * - FROM users - WHERE invite_code = $1 - ", - ) - .bind(code) - .fetch_optional(&mut tx) - .await? 
- .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), + pub async fn get_unsent_invites(&self, count: usize) -> Result> { + self.transact(|tx| async move { + Ok(signup::Entity::find() + .select_only() + .column(signup::Column::EmailAddress) + .column(signup::Column::EmailConfirmationCode) + .filter( + signup::Column::EmailConfirmationSent.eq(false).and( + signup::Column::PlatformMac + .eq(true) + .or(signup::Column::PlatformUnknown.eq(true)), + ), ) - }) + .limit(count as u64) + .into_model() + .all(&tx) + .await?) }) .await } - async fn commit_room_transaction<'a, T>( - &'a self, - room_id: RoomId, - tx: sqlx::Transaction<'static, D>, - data: T, - ) -> Result> { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }) - } - - pub async fn create_room( - &self, - user_id: UserId, - connection_id: ConnectionId, - live_kit_room: &str, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar( - " - INSERT INTO rooms (live_kit_room) - VALUES ($1) - RETURNING id - ", - ) - .bind(&live_kit_room) - .fetch_one(&mut tx) - .await - .map(RoomId)?; - - sqlx::query( - " - INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(room_id) - .bind(user_id) - .bind(connection_id.0 as i32) - .bind(user_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }).await - } + // invite codes - pub async fn call( + pub async fn create_invite_from_code( &self, - room_id: RoomId, - calling_user_id: UserId, - calling_connection_id: ConnectionId, - called_user_id: UserId, - initial_project_id: Option, - ) -> Result> { - self.transact(|mut tx| async move 
{ - sqlx::query( - " - INSERT INTO room_participants ( - room_id, - user_id, - calling_user_id, - calling_connection_id, - initial_project_id - ) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(room_id) - .bind(called_user_id) - .bind(calling_user_id) - .bind(calling_connection_id.0 as i32) - .bind(initial_project_id) - .execute(&mut tx) - .await?; + code: &str, + email_address: &str, + device_id: Option<&str>, + ) -> Result { + self.transact(|tx| async move { + let existing_user = user::Entity::find() + .filter(user::Column::EmailAddress.eq(email_address)) + .one(&tx) + .await?; - let room = self.get_room(room_id, &mut tx).await?; - let incoming_call = Self::build_incoming_call(&room, called_user_id) - .ok_or_else(|| anyhow!("failed to build incoming call"))?; - self.commit_room_transaction(room_id, tx, (room, incoming_call)) - .await - }) - .await - } + if existing_user.is_some() { + Err(anyhow!("email address is already in use"))?; + } - pub async fn incoming_call_for_user( - &self, - user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar::<_, RoomId>( - " - SELECT room_id - FROM room_participants - WHERE user_id = $1 AND answering_connection_id IS NULL - ", - ) - .bind(user_id) - .fetch_optional(&mut tx) - .await?; + let inviter = match user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&tx) + .await? + { + Some(inviter) => inviter, + None => { + return Err(Error::Http( + StatusCode::NOT_FOUND, + "invite code not found".to_string(), + ))? 
+ } + }; - if let Some(room_id) = room_id { - let room = self.get_room(room_id, &mut tx).await?; - Ok(Self::build_incoming_call(&room, user_id)) - } else { - Ok(None) + if inviter.invite_count == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; } - }) - .await - } - - fn build_incoming_call( - room: &proto::Room, - called_user_id: UserId, - ) -> Option { - let pending_participant = room - .pending_participants - .iter() - .find(|participant| participant.user_id == called_user_id.to_proto())?; - - Some(proto::IncomingCall { - room_id: room.id, - calling_user_id: pending_participant.calling_user_id, - participant_user_ids: room - .participants - .iter() - .map(|participant| participant.user_id) - .collect(), - initial_project: room.participants.iter().find_map(|participant| { - let initial_project_id = pending_participant.initial_project_id?; - participant - .projects - .iter() - .find(|project| project.id == initial_project_id) - .cloned() - }), - }) - } - pub async fn call_failed( - &self, - room_id: RoomId, - called_user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - sqlx::query( - " - DELETE FROM room_participants - WHERE room_id = $1 AND user_id = $2 - ", + let signup = signup::Entity::insert(signup::ActiveModel { + email_address: ActiveValue::set(email_address.into()), + email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), + email_confirmation_sent: ActiveValue::set(false), + inviting_user_id: ActiveValue::set(Some(inviter.id)), + platform_linux: ActiveValue::set(false), + platform_mac: ActiveValue::set(false), + platform_windows: ActiveValue::set(false), + platform_unknown: ActiveValue::set(true), + device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())), + ..Default::default() + }) + .on_conflict( + OnConflict::column(signup::Column::EmailAddress) + .update_column(signup::Column::InvitingUserId) + .to_owned(), ) - .bind(room_id) - 
.bind(called_user_id) - .execute(&mut tx) + .exec_with_returning(&tx) .await?; + tx.commit().await?; - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok(Invite { + email_address: signup.email_address, + email_confirmation_code: signup.email_confirmation_code, + }) }) .await } - pub async fn decline_call( + pub async fn create_user_from_invite( &self, - expected_room_id: Option, - user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE user_id = $1 AND answering_connection_id IS NULL - RETURNING room_id - ", - ) - .bind(user_id) - .fetch_one(&mut tx) - .await?; - if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { - return Err(anyhow!("declining call on unexpected room"))?; - } - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }) - .await - } + invite: &Invite, + user: NewUserParams, + ) -> Result> { + self.transact(|tx| async { + let tx = tx; + let signup = signup::Entity::find() + .filter( + signup::Column::EmailAddress + .eq(invite.email_address.as_str()) + .and( + signup::Column::EmailConfirmationCode + .eq(invite.email_confirmation_code.as_str()), + ), + ) + .one(&tx) + .await? 
+ .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - pub async fn cancel_call( - &self, - expected_room_id: Option, - calling_connection_id: ConnectionId, - called_user_id: UserId, - ) -> Result> { - self.transact(|mut tx| async move { - let room_id = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL - RETURNING room_id - ", - ) - .bind(called_user_id) - .bind(calling_connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { - return Err(anyhow!("canceling call on unexpected room"))?; + if signup.user_id.is_some() { + return Ok(None); } - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }).await - } - - pub async fn join_room( - &self, - room_id: RoomId, - user_id: UserId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - sqlx::query( - " - UPDATE room_participants - SET answering_connection_id = $1 - WHERE room_id = $2 AND user_id = $3 - RETURNING 1 - ", - ) - .bind(connection_id.0 as i32) - .bind(room_id) - .bind(user_id) - .fetch_one(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await - }) - .await - } - - pub async fn leave_room( - &self, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async move { - // Leave room. 
- let room_id = sqlx::query_scalar::<_, RoomId>( - " - DELETE FROM room_participants - WHERE answering_connection_id = $1 - RETURNING room_id - ", + let user = user::Entity::insert(user::ActiveModel { + email_address: ActiveValue::set(Some(invite.email_address.clone())), + github_login: ActiveValue::set(user.github_login.clone()), + github_user_id: ActiveValue::set(Some(user.github_user_id)), + admin: ActiveValue::set(false), + invite_count: ActiveValue::set(user.invite_count), + invite_code: ActiveValue::set(Some(random_invite_code())), + metrics_id: ActiveValue::set(Uuid::new_v4()), + ..Default::default() + }) + .on_conflict( + OnConflict::column(user::Column::GithubLogin) + .update_columns([ + user::Column::EmailAddress, + user::Column::GithubUserId, + user::Column::Admin, + ]) + .to_owned(), ) - .bind(connection_id.0 as i32) - .fetch_optional(&mut tx) + .exec_with_returning(&tx) .await?; - if let Some(room_id) = room_id { - // Cancel pending calls initiated by the leaving user. - let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( - " - DELETE FROM room_participants - WHERE calling_connection_id = $1 AND answering_connection_id IS NULL - RETURNING user_id - ", - ) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - .await?; - - let project_ids = sqlx::query_scalar::<_, ProjectId>( - " - SELECT project_id - FROM project_collaborators - WHERE connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .fetch_all(&mut tx) - .await?; - - // Leave projects. 
- let mut left_projects = HashMap::default(); - if !project_ids.is_empty() { - let mut params = "?,".repeat(project_ids.len()); - params.pop(); - let query = format!( - " - SELECT * - FROM project_collaborators - WHERE project_id IN ({params}) - " - ); - let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query); - for project_id in project_ids { - query = query.bind(project_id); - } + let mut signup = signup.into_active_model(); + signup.user_id = ActiveValue::set(Some(user.id)); + let signup = signup.update(&tx).await?; - let mut project_collaborators = query.fetch(&mut tx); - while let Some(collaborator) = project_collaborators.next().await { - let collaborator = collaborator?; - let left_project = - left_projects - .entry(collaborator.project_id) - .or_insert(LeftProject { - id: collaborator.project_id, - host_user_id: Default::default(), - connection_ids: Default::default(), - host_connection_id: Default::default(), - }); - - let collaborator_connection_id = - ConnectionId(collaborator.connection_id as u32); - if collaborator_connection_id != connection_id { - left_project.connection_ids.push(collaborator_connection_id); - } + if let Some(inviting_user_id) = signup.inviting_user_id { + let result = user::Entity::update_many() + .filter( + user::Column::Id + .eq(inviting_user_id) + .and(user::Column::InviteCount.gt(0)), + ) + .col_expr( + user::Column::InviteCount, + Expr::col(user::Column::InviteCount).sub(1), + ) + .exec(&tx) + .await?; - if collaborator.is_host { - left_project.host_user_id = collaborator.user_id; - left_project.host_connection_id = - ConnectionId(collaborator.connection_id as u32); - } - } + if result.rows_affected == 0 { + Err(Error::Http( + StatusCode::UNAUTHORIZED, + "no invites remaining".to_string(), + ))?; } - sqlx::query( - " - DELETE FROM project_collaborators - WHERE connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - // Unshare projects. 
- sqlx::query( - " - DELETE FROM projects - WHERE room_id = $1 AND host_connection_id = $2 - ", - ) - .bind(room_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) + contact::Entity::insert(contact::ActiveModel { + user_id_a: ActiveValue::set(inviting_user_id), + user_id_b: ActiveValue::set(user.id), + a_to_b: ActiveValue::set(true), + should_notify: ActiveValue::set(true), + accepted: ActiveValue::set(true), + ..Default::default() + }) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&tx) .await?; - - let room = self.get_room(room_id, &mut tx).await?; - Ok(Some( - self.commit_room_transaction( - room_id, - tx, - LeftRoom { - room, - left_projects, - canceled_calls_to_user_ids, - }, - ) - .await?, - )) - } else { - Ok(None) } + + tx.commit().await?; + Ok(Some(NewUserResult { + user_id: user.id, + metrics_id: user.metrics_id.to_string(), + inviting_user_id: signup.inviting_user_id, + signup_device_id: signup.device_id, + })) }) .await } - pub async fn update_room_participant_location( - &self, - room_id: RoomId, - connection_id: ConnectionId, - location: proto::ParticipantLocation, - ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; - let location_kind; - let location_project_id; - match location - .variant - .as_ref() - .ok_or_else(|| anyhow!("invalid location"))? 
- { - proto::participant_location::Variant::SharedProject(project) => { - location_kind = 0; - location_project_id = Some(ProjectId::from_proto(project.id)); - } - proto::participant_location::Variant::UnsharedProject(_) => { - location_kind = 1; - location_project_id = None; - } - proto::participant_location::Variant::External(_) => { - location_kind = 2; - location_project_id = None; - } + pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { + self.transact(|tx| async move { + if count > 0 { + user::Entity::update_many() + .filter( + user::Column::Id + .eq(id) + .and(user::Column::InviteCode.is_null()), + ) + .col_expr(user::Column::InviteCode, random_invite_code().into()) + .exec(&tx) + .await?; } - sqlx::query( - " - UPDATE room_participants - SET location_kind = $1, location_project_id = $2 - WHERE room_id = $3 AND answering_connection_id = $4 - RETURNING 1 - ", - ) - .bind(location_kind) - .bind(location_project_id) - .bind(room_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .col_expr(user::Column::InviteCount, count.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) }) .await } - async fn get_guest_connection_ids( - &self, - project_id: ProjectId, - tx: &mut sqlx::Transaction<'_, D>, - ) -> Result> { - let mut guest_connection_ids = Vec::new(); - let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 AND is_host = FALSE - ", - ) - .bind(project_id) - .fetch(tx); - while let Some(connection_id) = db_guest_connection_ids.next().await { - guest_connection_ids.push(ConnectionId(connection_id? 
as u32)); - } - Ok(guest_connection_ids) - } - - async fn get_room( - &self, - room_id: RoomId, - tx: &mut sqlx::Transaction<'_, D>, - ) -> Result { - let room: Room = sqlx::query_as( - " - SELECT * - FROM rooms - WHERE id = $1 - ", - ) - .bind(room_id) - .fetch_one(&mut *tx) - .await?; - - let mut db_participants = - sqlx::query_as::<_, (UserId, Option, Option, Option, UserId, Option)>( - " - SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id - FROM room_participants - WHERE room_id = $1 - ", - ) - .bind(room_id) - .fetch(&mut *tx); - - let mut participants = HashMap::default(); - let mut pending_participants = Vec::new(); - while let Some(participant) = db_participants.next().await { - let ( - user_id, - answering_connection_id, - location_kind, - location_project_id, - calling_user_id, - initial_project_id, - ) = participant?; - if let Some(answering_connection_id) = answering_connection_id { - let location = match (location_kind, location_project_id) { - (Some(0), Some(project_id)) => { - Some(proto::participant_location::Variant::SharedProject( - proto::participant_location::SharedProject { - id: project_id.to_proto(), - }, - )) - } - (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( - Default::default(), - )), - _ => Some(proto::participant_location::Variant::External( - Default::default(), - )), - }; - participants.insert( - answering_connection_id, - proto::Participant { - user_id: user_id.to_proto(), - peer_id: answering_connection_id as u32, - projects: Default::default(), - location: Some(proto::ParticipantLocation { variant: location }), - }, - ); - } else { - pending_participants.push(proto::PendingParticipant { - user_id: user_id.to_proto(), - calling_user_id: calling_user_id.to_proto(), - initial_project_id: initial_project_id.map(|id| id.to_proto()), - }); - } - } - drop(db_participants); - - let mut rows = sqlx::query_as::<_, (i32, ProjectId, Option)>( - " - 
SELECT host_connection_id, projects.id, worktrees.root_name - FROM projects - LEFT JOIN worktrees ON projects.id = worktrees.project_id - WHERE room_id = $1 - ", - ) - .bind(room_id) - .fetch(&mut *tx); - - while let Some(row) = rows.next().await { - let (connection_id, project_id, worktree_root_name) = row?; - if let Some(participant) = participants.get_mut(&connection_id) { - let project = if let Some(project) = participant - .projects - .iter_mut() - .find(|project| project.id == project_id.to_proto()) - { - project - } else { - participant.projects.push(proto::ParticipantProject { - id: project_id.to_proto(), - worktree_root_names: Default::default(), - }); - participant.projects.last_mut().unwrap() - }; - project.worktree_root_names.extend(worktree_root_name); + pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { + self.transact(|tx| async move { + match user::Entity::find_by_id(id).one(&tx).await? { + Some(user) if user.invite_code.is_some() => { + Ok(Some((user.invite_code.unwrap(), user.invite_count as u32))) + } + _ => Ok(None), } - } - - Ok(proto::Room { - id: room.id.to_proto(), - live_kit_room: room.live_kit_room, - participants: participants.into_values().collect(), - pending_participants, }) + .await } - // projects - - pub async fn project_count_excluding_admins(&self) -> Result { - self.transact(|mut tx| async move { - Ok(sqlx::query_scalar::<_, i32>( - " - SELECT COUNT(*) - FROM projects, users - WHERE projects.host_user_id = users.id AND users.admin IS FALSE - ", - ) - .fetch_one(&mut tx) - .await? as usize) + pub async fn get_user_for_invite_code(&self, code: &str) -> Result { + self.transact(|tx| async move { + user::Entity::find() + .filter(user::Column::InviteCode.eq(code)) + .one(&tx) + .await? 
+ .ok_or_else(|| { + Error::Http( + StatusCode::NOT_FOUND, + "that invite code does not exist".to_string(), + ) + }) }) .await } + // projects + pub async fn share_project( &self, - expected_room_id: RoomId, + room_id: RoomId, connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result> { - self.transact(|mut tx| async move { - let (sql, values) = self.build_query( - Query::select() - .columns([ - schema::room_participant::Definition::RoomId, - schema::room_participant::Definition::UserId, - ]) - .from(schema::room_participant::Definition::Table) - .and_where( - Expr::col(schema::room_participant::Definition::AnsweringConnectionId) - .eq(connection_id.0), - ), - ); - let (room_id, user_id) = sqlx::query_as_with::<_, (RoomId, UserId), _>(&sql, values) - .fetch_one(&mut tx) - .await?; - if room_id != expected_room_id { + self.transact(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find participant"))?; + if participant.room_id != room_id { return Err(anyhow!("shared project on unexpected room"))?; } - let (sql, values) = self.build_query( - Query::insert() - .into_table(schema::project::Definition::Table) - .columns([ - schema::project::Definition::RoomId, - schema::project::Definition::HostUserId, - schema::project::Definition::HostConnectionId, - ]) - .values_panic([room_id.into(), user_id.into(), connection_id.0.into()]) - .returning_col(schema::project::Definition::Id), - ); - let project_id: ProjectId = sqlx::query_scalar_with(&sql, values) - .fetch_one(&mut tx) - .await?; - - if !worktrees.is_empty() { - let mut query = Query::insert() - .into_table(schema::worktree::Definition::Table) - .columns([ - schema::worktree::Definition::ProjectId, - schema::worktree::Definition::Id, - schema::worktree::Definition::RootName, - schema::worktree::Definition::AbsPath, - schema::worktree::Definition::Visible, - schema::worktree::Definition::ScanId, - schema::worktree::Definition::IsComplete, - ]) - .to_owned(); - for worktree in worktrees { - query.values_panic([ - project_id.into(), - worktree.id.into(), - worktree.root_name.clone().into(), - worktree.abs_path.clone().into(), - worktree.visible.into(), - 0.into(), - false.into(), - ]); - } - let (sql, values) = self.build_query(&query); - sqlx::query_with(&sql, values).execute(&mut tx).await?; - } - - sqlx::query( - " - INSERT INTO project_collaborators ( - project_id, - connection_id, - user_id, - replica_id, - is_host - ) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .bind(user_id) - .bind(0) - .bind(true) - .execute(&mut tx) - .await?; - - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, (project_id, room)) - .await - }) - .await - } - - pub async fn unshare_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result)>> { - 
self.transact(|mut tx| async move { - let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - let room_id: RoomId = sqlx::query_scalar( - " - DELETE FROM projects - WHERE id = $1 AND host_connection_id = $2 - RETURNING room_id - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - .await - }) - .await - } - - pub async fn update_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - worktrees: &[proto::WorktreeMetadata], - ) -> Result)>> { - self.transact(|mut tx| async move { - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - if !worktrees.is_empty() { - let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); - params.pop(); - let query = format!( - " - INSERT INTO worktrees ( - project_id, - id, - root_name, - abs_path, - visible, - scan_id, - is_complete - ) - VALUES {params} - ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name - " - ); - - let mut query = sqlx::query(&query); - for worktree in worktrees { - query = query - .bind(project_id) - .bind(worktree.id as i32) - .bind(&worktree.root_name) - .bind(&worktree.abs_path) - .bind(worktree.visible) - .bind(0) - .bind(false) - } - query.execute(&mut tx).await?; - } - - let mut params = "?,".repeat(worktrees.len()); - if !worktrees.is_empty() { - params.pop(); - } - let query = format!( - " - DELETE FROM worktrees - WHERE project_id = ? 
AND id NOT IN ({params}) - ", - ); - - let mut query = sqlx::query(&query).bind(project_id); - for worktree in worktrees { - query = query.bind(WorktreeId(worktree.id as i32)); - } - query.execute(&mut tx).await?; - - let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - .await - }) - .await - } - - pub async fn update_worktree( - &self, - update: &proto::UpdateWorktree, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async move { - let project_id = ProjectId::from_proto(update.project_id); - let worktree_id = WorktreeId::from_proto(update.worktree_id); - - // Ensure the update comes from the host. - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Update metadata. 
- sqlx::query( - " - UPDATE worktrees - SET - root_name = $1, - scan_id = $2, - is_complete = $3, - abs_path = $4 - WHERE project_id = $5 AND id = $6 - RETURNING 1 - ", - ) - .bind(&update.root_name) - .bind(update.scan_id as i64) - .bind(update.is_last_update) - .bind(&update.abs_path) - .bind(project_id) - .bind(worktree_id) - .fetch_one(&mut tx) - .await?; - - if !update.updated_entries.is_empty() { - let mut params = - "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len()); - params.pop(); - - let query = format!( - " - INSERT INTO worktree_entries ( - project_id, - worktree_id, - id, - is_dir, - path, - inode, - mtime_seconds, - mtime_nanos, - is_symlink, - is_ignored - ) - VALUES {params} - ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET - is_dir = excluded.is_dir, - path = excluded.path, - inode = excluded.inode, - mtime_seconds = excluded.mtime_seconds, - mtime_nanos = excluded.mtime_nanos, - is_symlink = excluded.is_symlink, - is_ignored = excluded.is_ignored - " - ); - let mut query = sqlx::query(&query); - for entry in &update.updated_entries { - let mtime = entry.mtime.clone().unwrap_or_default(); - query = query - .bind(project_id) - .bind(worktree_id) - .bind(entry.id as i64) - .bind(entry.is_dir) - .bind(&entry.path) - .bind(entry.inode as i64) - .bind(mtime.seconds as i64) - .bind(mtime.nanos as i32) - .bind(entry.is_symlink) - .bind(entry.is_ignored); - } - query.execute(&mut tx).await?; - } - - if !update.removed_entries.is_empty() { - let mut params = "?,".repeat(update.removed_entries.len()); - params.pop(); - let query = format!( - " - DELETE FROM worktree_entries - WHERE project_id = ? AND worktree_id = ? 
AND id IN ({params}) - " - ); - - let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id); - for entry_id in &update.removed_entries { - query = query.bind(*entry_id as i64); - } - query.execute(&mut tx).await?; - } - - let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await - }) - .await - } - - pub async fn update_diagnostic_summary( - &self, - update: &proto::UpdateDiagnosticSummary, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async { - let project_id = ProjectId::from_proto(update.project_id); - let worktree_id = WorktreeId::from_proto(update.worktree_id); - let summary = update - .summary - .as_ref() - .ok_or_else(|| anyhow!("invalid summary"))?; - - // Ensure the update comes from the host. - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Update summary. 
- sqlx::query( - " - INSERT INTO worktree_diagnostic_summaries ( - project_id, - worktree_id, - path, - language_server_id, - error_count, - warning_count - ) - VALUES ($1, $2, $3, $4, $5, $6) - ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET - language_server_id = excluded.language_server_id, - error_count = excluded.error_count, - warning_count = excluded.warning_count - ", - ) - .bind(project_id) - .bind(worktree_id) - .bind(&summary.path) - .bind(summary.language_server_id as i64) - .bind(summary.error_count as i32) - .bind(summary.warning_count as i32) - .execute(&mut tx) - .await?; - - let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await - }) - .await - } - - pub async fn start_language_server( - &self, - update: &proto::StartLanguageServer, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|mut tx| async { - let project_id = ProjectId::from_proto(update.project_id); - let server = update - .server - .as_ref() - .ok_or_else(|| anyhow!("invalid language server"))?; - - // Ensure the update comes from the host. - let room_id: RoomId = sqlx::query_scalar( - " - SELECT room_id - FROM projects - WHERE id = $1 AND host_connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Add the newly-started language server. 
- sqlx::query( - " - INSERT INTO language_servers (project_id, id, name) - VALUES ($1, $2, $3) - ON CONFLICT (project_id, id) DO UPDATE SET - name = excluded.name - ", - ) - .bind(project_id) - .bind(server.id as i64) - .bind(&server.name) - .execute(&mut tx) - .await?; - - let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await - }) - .await - } - - pub async fn join_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( - " - SELECT room_id, user_id - FROM room_participants - WHERE answering_connection_id = $1 - ", - ) - .bind(connection_id.0 as i32) - .fetch_one(&mut tx) - .await?; - - // Ensure project id was shared on this room. - sqlx::query( - " - SELECT 1 - FROM projects - WHERE id = $1 AND room_id = $2 - ", - ) - .bind(project_id) - .bind(room_id) - .fetch_one(&mut tx) - .await?; - - let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>( - " - SELECT * - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - let replica_ids = collaborators - .iter() - .map(|c| c.replica_id) - .collect::>(); - let mut replica_id = ReplicaId(1); - while replica_ids.contains(&replica_id) { - replica_id.0 += 1; - } - let new_collaborator = ProjectCollaborator { - project_id, - connection_id: connection_id.0 as i32, - user_id, - replica_id, - is_host: false, - }; - - sqlx::query( - " - INSERT INTO project_collaborators ( - project_id, - connection_id, - user_id, - replica_id, - is_host - ) - VALUES ($1, $2, $3, $4, $5) - ", - ) - .bind(new_collaborator.project_id) - .bind(new_collaborator.connection_id) - .bind(new_collaborator.user_id) - .bind(new_collaborator.replica_id) - .bind(new_collaborator.is_host) - .execute(&mut tx) - .await?; - collaborators.push(new_collaborator); - - 
let worktree_rows = sqlx::query_as::<_, WorktreeRow>( - " - SELECT * - FROM worktrees - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - let mut worktrees = worktree_rows - .into_iter() - .map(|worktree_row| { - ( - worktree_row.id, - Worktree { - id: worktree_row.id, - abs_path: worktree_row.abs_path, - root_name: worktree_row.root_name, - visible: worktree_row.visible, - entries: Default::default(), - diagnostic_summaries: Default::default(), - scan_id: worktree_row.scan_id as u64, - is_complete: worktree_row.is_complete, - }, - ) - }) - .collect::>(); - - // Populate worktree entries. - { - let mut entries = sqlx::query_as::<_, WorktreeEntry>( - " - SELECT * - FROM worktree_entries - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch(&mut tx); - while let Some(entry) = entries.next().await { - let entry = entry?; - if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) { - worktree.entries.push(proto::Entry { - id: entry.id as u64, - is_dir: entry.is_dir, - path: entry.path, - inode: entry.inode as u64, - mtime: Some(proto::Timestamp { - seconds: entry.mtime_seconds as u64, - nanos: entry.mtime_nanos as u32, - }), - is_symlink: entry.is_symlink, - is_ignored: entry.is_ignored, - }); - } - } - } - - // Populate worktree diagnostic summaries. - { - let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>( - " - SELECT * - FROM worktree_diagnostic_summaries - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch(&mut tx); - while let Some(summary) = summaries.next().await { - let summary = summary?; - if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) { - worktree - .diagnostic_summaries - .push(proto::DiagnosticSummary { - path: summary.path, - language_server_id: summary.language_server_id as u64, - error_count: summary.error_count as u32, - warning_count: summary.warning_count as u32, - }); - } - } - } - - // Populate language servers. 
- let language_servers = sqlx::query_as::<_, LanguageServer>( - " - SELECT * - FROM language_servers - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - - self.commit_room_transaction( - room_id, - tx, - ( - Project { - collaborators, - worktrees, - language_servers: language_servers - .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id.to_proto(), - name: language_server.name, - }) - .collect(), - }, - replica_id as ReplicaId, - ), - ) - .await - }) - .await - } - - pub async fn leave_project( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let result = sqlx::query( - " - DELETE FROM project_collaborators - WHERE project_id = $1 AND connection_id = $2 - ", - ) - .bind(project_id) - .bind(connection_id.0 as i32) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 0 { - Err(anyhow!("not a collaborator on this project"))?; - } - - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await? 
- .into_iter() - .map(|id| ConnectionId(id as u32)) - .collect(); - - let (room_id, host_user_id, host_connection_id) = - sqlx::query_as::<_, (RoomId, i32, i32)>( - " - SELECT room_id, host_user_id, host_connection_id - FROM projects - WHERE id = $1 - ", - ) - .bind(project_id) - .fetch_one(&mut tx) - .await?; - - self.commit_room_transaction( - room_id, - tx, - LeftProject { - id: project_id, - host_user_id: UserId(host_user_id), - host_connection_id: ConnectionId(host_connection_id as u32), - connection_ids, - }, - ) - .await - }) - .await - } - - pub async fn project_collaborators( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let collaborators = sqlx::query_as::<_, ProjectCollaborator>( - " - SELECT * - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - - if collaborators - .iter() - .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) - { - Ok(collaborators) - } else { - Err(anyhow!("no such project"))? - } - }) - .await - } - - pub async fn project_connection_ids( - &self, - project_id: ProjectId, - connection_id: ConnectionId, - ) -> Result> { - self.transact(|mut tx| async move { - let connection_ids = sqlx::query_scalar::<_, i32>( - " - SELECT connection_id - FROM project_collaborators - WHERE project_id = $1 - ", - ) - .bind(project_id) - .fetch_all(&mut tx) - .await?; - - if connection_ids.contains(&(connection_id.0 as i32)) { - Ok(connection_ids - .into_iter() - .map(|connection_id| ConnectionId(connection_id as u32)) - .collect()) - } else { - Err(anyhow!("no such project"))? 
- } - }) - .await - } - - // contacts - - pub async fn get_contacts(&self, user_id: UserId) -> Result> { - self.transact(|mut tx| async move { - let query = " - SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify, (room_participants.id IS NOT NULL) as busy - FROM contacts - LEFT JOIN room_participants ON room_participants.user_id = $1 - WHERE user_id_a = $1 OR user_id_b = $1; - "; - - let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool, bool)>(query) - .bind(user_id) - .fetch(&mut tx); - - let mut contacts = Vec::new(); - while let Some(row) = rows.next().await { - let (user_id_a, user_id_b, a_to_b, accepted, should_notify, busy) = row?; - if user_id_a == user_id { - if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_b, - should_notify: should_notify && a_to_b, - busy - }); - } else if a_to_b { - contacts.push(Contact::Outgoing { user_id: user_id_b }) - } else { - contacts.push(Contact::Incoming { - user_id: user_id_b, - should_notify, - }); - } - } else if accepted { - contacts.push(Contact::Accepted { - user_id: user_id_a, - should_notify: should_notify && !a_to_b, - busy - }); - } else if a_to_b { - contacts.push(Contact::Incoming { - user_id: user_id_a, - should_notify, - }); - } else { - contacts.push(Contact::Outgoing { user_id: user_id_a }); - } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - - Ok(contacts) - }) - .await - } - - pub async fn is_user_busy(&self, user_id: UserId) -> Result { - self.transact(|mut tx| async move { - Ok(sqlx::query_scalar::<_, i32>( - " - SELECT 1 - FROM room_participants - WHERE room_participants.user_id = $1 - ", - ) - .bind(user_id) - .fetch_optional(&mut tx) - .await? 
- .is_some()) - }) - .await - } - - pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - self.transact(|mut tx| async move { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) - }; - - let query = " - SELECT 1 FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = TRUE - LIMIT 1 - "; - Ok(sqlx::query_scalar::<_, i32>(query) - .bind(id_a.0) - .bind(id_b.0) - .fetch_optional(&mut tx) - .await? - .is_some()) - }) - .await - } - - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - let query = " - INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify) - VALUES ($1, $2, $3, FALSE, TRUE) - ON CONFLICT (user_id_a, user_id_b) DO UPDATE - SET - accepted = TRUE, - should_notify = FALSE - WHERE - NOT contacts.accepted AND - ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR - (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("contact already requested"))? 
- } - }).await - } - - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2; - "; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .execute(&mut tx) - .await?; - - if result.rows_affected() == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact"))? - } - }) - .await - } - - pub async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b, a_to_b) = if user_id < contact_user_id { - (user_id, contact_user_id, true) - } else { - (contact_user_id, user_id, false) - }; - - let query = " - UPDATE contacts - SET should_notify = FALSE - WHERE - user_id_a = $1 AND user_id_b = $2 AND - ( - (a_to_b = $3 AND accepted) OR - (a_to_b != $3 AND NOT accepted) - ); - "; + let project = project::ActiveModel { + room_id: ActiveValue::set(participant.room_id), + host_user_id: ActiveValue::set(participant.user_id), + host_connection_id: ActiveValue::set(connection_id.0 as i32), + ..Default::default() + } + .insert(&tx) + .await?; - let result = sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await?; + worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i32), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + })) + .exec(&tx) + .await?; - if result.rows_affected() == 0 { - Err(anyhow!("no such contact request"))? 
- } else { - tx.commit().await?; - Ok(()) + project_collaborator::ActiveModel { + project_id: ActiveValue::set(project.id), + connection_id: ActiveValue::set(connection_id.0 as i32), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(0), + is_host: ActiveValue::set(true), + ..Default::default() } + .insert(&tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, (project.id, room)) + .await }) .await } - pub async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - self.transact(|mut tx| async move { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - let result = if accept { - let query = " - UPDATE contacts - SET accepted = TRUE, should_notify = TRUE - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; - "; - sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await? - } else { - let query = " - DELETE FROM contacts - WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; - "; - sqlx::query(query) - .bind(id_a.0) - .bind(id_b.0) - .bind(a_to_b) - .execute(&mut tx) - .await? - }; - if result.rows_affected() == 1 { - tx.commit().await?; - Ok(()) + async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { + let db_room = room::Entity::find_by_id(room_id) + .one(tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find room"))?; + + let mut db_participants = db_room + .find_related(room_participant::Entity) + .stream(tx) + .await?; + let mut participants = HashMap::default(); + let mut pending_participants = Vec::new(); + while let Some(db_participant) = db_participants.next().await { + let db_participant = db_participant?; + if let Some(answering_connection_id) = db_participant.answering_connection_id { + let location = match ( + db_participant.location_kind, + db_participant.location_project_id, + ) { + (Some(0), Some(project_id)) => { + Some(proto::participant_location::Variant::SharedProject( + proto::participant_location::SharedProject { + id: project_id.to_proto(), + }, + )) + } + (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( + Default::default(), + )), + _ => Some(proto::participant_location::Variant::External( + Default::default(), + )), + }; + participants.insert( + answering_connection_id, + proto::Participant { + user_id: db_participant.user_id.to_proto(), + peer_id: answering_connection_id as u32, + projects: Default::default(), + location: Some(proto::ParticipantLocation { variant: location }), + }, + ); } else { - Err(anyhow!("no such contact request"))? 
+ pending_participants.push(proto::PendingParticipant { + user_id: db_participant.user_id.to_proto(), + calling_user_id: db_participant.calling_user_id.to_proto(), + initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), + }); + } + } + + let mut db_projects = db_room + .find_related(project::Entity) + .find_with_related(worktree::Entity) + .stream(tx) + .await?; + + while let Some(row) = db_projects.next().await { + let (db_project, db_worktree) = row?; + if let Some(participant) = participants.get_mut(&db_project.host_connection_id) { + let project = if let Some(project) = participant + .projects + .iter_mut() + .find(|project| project.id == db_project.id.to_proto()) + { + project + } else { + participant.projects.push(proto::ParticipantProject { + id: db_project.id.to_proto(), + worktree_root_names: Default::default(), + }); + participant.projects.last_mut().unwrap() + }; + + if let Some(db_worktree) = db_worktree { + project.worktree_root_names.push(db_worktree.root_name); + } } + } + + Ok(proto::Room { + id: db_room.id.to_proto(), + live_kit_room: db_room.live_kit_room, + participants: participants.into_values().collect(), + pending_participants, }) - .await } - // access tokens + async fn commit_room_transaction( + &self, + room_id: RoomId, + tx: DatabaseTransaction, + data: T, + ) -> Result> { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + tx.commit().await?; + Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }) + } pub async fn create_access_token_hash( &self, @@ -2545,49 +1064,51 @@ where max_access_token_count: usize, ) -> Result<()> { self.transact(|tx| async { - let mut tx = tx; - let insert_query = " - INSERT INTO access_tokens (user_id, hash) - VALUES ($1, $2); - "; - let cleanup_query = " - DELETE FROM access_tokens - WHERE id IN ( - SELECT id from access_tokens - WHERE user_id = $1 - ORDER BY id DESC - LIMIT 10000 - OFFSET $3 - ) - "; + let tx = tx; - 
sqlx::query(insert_query) - .bind(user_id.0) - .bind(access_token_hash) - .execute(&mut tx) - .await?; - sqlx::query(cleanup_query) - .bind(user_id.0) - .bind(access_token_hash) - .bind(max_access_token_count as i32) - .execute(&mut tx) + access_token::ActiveModel { + user_id: ActiveValue::set(user_id), + hash: ActiveValue::set(access_token_hash.into()), + ..Default::default() + } + .insert(&tx) + .await?; + + access_token::Entity::delete_many() + .filter( + access_token::Column::Id.in_subquery( + Query::select() + .column(access_token::Column::Id) + .from(access_token::Entity) + .and_where(access_token::Column::UserId.eq(user_id)) + .order_by(access_token::Column::Id, sea_orm::Order::Desc) + .limit(10000) + .offset(max_access_token_count as u64) + .to_owned(), + ), + ) + .exec(&tx) .await?; - Ok(tx.commit().await?) + tx.commit().await?; + Ok(()) }) .await } pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - self.transact(|mut tx| async move { - let query = " - SELECT hash - FROM access_tokens - WHERE user_id = $1 - ORDER BY id DESC - "; - Ok(sqlx::query_scalar(query) - .bind(user_id.0) - .fetch_all(&mut tx) + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Hash, + } + + self.transact(|tx| async move { + Ok(access_token::Entity::find() + .select_only() + .column(access_token::Column::Hash) + .filter(access_token::Column::UserId.eq(user_id)) + .order_by_desc(access_token::Column::Id) + .into_values::<_, QueryAs>() + .all(&tx) .await?) 
}) .await @@ -2595,21 +1116,33 @@ where async fn transact(&self, f: F) -> Result where - F: Send + Fn(sqlx::Transaction<'static, D>) -> Fut, + F: Send + Fn(DatabaseTransaction) -> Fut, Fut: Send + Future>, { let body = async { loop { - let tx = self.begin_transaction().await?; + let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + match f(tx).await { Ok(result) => return Ok(result), Err(error) => match error { - Error::Database(error) - if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => + Error::Database2( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => { // Retry (don't break the loop) } @@ -2635,6 +1168,49 @@ where } } +pub struct RoomGuard { + data: T, + _guard: OwnedMutexGuard<()>, + _not_send: PhantomData>, +} + +impl Deref for RoomGuard { + type Target = T; + + fn deref(&self) -> &T { + &self.data + } +} + +impl DerefMut for RoomGuard { + fn deref_mut(&mut self) -> &mut T { + &mut self.data + } +} + +#[derive(Debug, Serialize, Deserialize)] +pub struct NewUserParams { + pub github_login: String, + pub github_user_id: i32, + pub invite_count: i32, +} + +#[derive(Debug)] +pub struct NewUserResult { + pub user_id: UserId, + pub metrics_id: String, + pub inviting_user_id: Option, + pub signup_device_id: Option, +} + +fn random_invite_code() -> String { + nanoid::nanoid!(16) +} + +fn random_email_confirmation_code() -> String { + nanoid::nanoid!(64) +} + macro_rules! id_type { ($name:ident) => { #[derive( @@ -2681,196 +1257,90 @@ macro_rules! 
id_type { sea_query::Value::Int(Some(value.0)) } } - }; -} - -id_type!(UserId); -#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] -pub struct User { - pub id: UserId, - pub github_login: String, - pub github_user_id: Option, - pub email_address: Option, - pub admin: bool, - pub invite_code: Option, - pub invite_count: i32, - pub connected_once: bool, -} - -id_type!(RoomId); -#[derive(Clone, Debug, Default, FromRow, Serialize, PartialEq)] -pub struct Room { - pub id: RoomId, - pub live_kit_room: String, -} - -id_type!(ProjectId); -pub struct Project { - pub collaborators: Vec, - pub worktrees: BTreeMap, - pub language_servers: Vec, -} - -id_type!(ReplicaId); -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -pub struct ProjectCollaborator { - pub project_id: ProjectId, - pub connection_id: i32, - pub user_id: UserId, - pub replica_id: ReplicaId, - pub is_host: bool, -} - -id_type!(WorktreeId); -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct WorktreeRow { - pub id: WorktreeId, - pub project_id: ProjectId, - pub abs_path: String, - pub root_name: String, - pub visible: bool, - pub scan_id: i64, - pub is_complete: bool, -} - -pub struct Worktree { - pub id: WorktreeId, - pub abs_path: String, - pub root_name: String, - pub visible: bool, - pub entries: Vec, - pub diagnostic_summaries: Vec, - pub scan_id: u64, - pub is_complete: bool, -} - -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct WorktreeEntry { - id: i64, - worktree_id: WorktreeId, - is_dir: bool, - path: String, - inode: i64, - mtime_seconds: i64, - mtime_nanos: i32, - is_symlink: bool, - is_ignored: bool, -} - -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct WorktreeDiagnosticSummary { - worktree_id: WorktreeId, - path: String, - language_server_id: i64, - error_count: i32, - warning_count: i32, -} - -id_type!(LanguageServerId); -#[derive(Clone, Debug, Default, FromRow, PartialEq)] -struct LanguageServer { - id: LanguageServerId, - name: String, -} - 
-pub struct LeftProject { - pub id: ProjectId, - pub host_user_id: UserId, - pub host_connection_id: ConnectionId, - pub connection_ids: Vec, -} - -pub struct LeftRoom { - pub room: proto::Room, - pub left_projects: HashMap, - pub canceled_calls_to_user_ids: Vec, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - busy: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, -} -impl Contact { - pub fn user_id(&self) -> UserId { - match self { - Contact::Accepted { user_id, .. } => *user_id, - Contact::Outgoing { user_id } => *user_id, - Contact::Incoming { user_id, .. } => *user_id, + impl sea_orm::TryGetable for $name { + fn try_get( + res: &sea_orm::QueryResult, + pre: &str, + col: &str, + ) -> Result { + Ok(Self(i32::try_get(res, pre, col)?)) + } } - } -} - -#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct IncomingContactRequest { - pub requester_id: UserId, - pub should_notify: bool, -} -#[derive(Clone, Deserialize)] -pub struct Signup { - pub email_address: String, - pub platform_mac: bool, - pub platform_windows: bool, - pub platform_linux: bool, - pub editor_features: Vec, - pub programming_languages: Vec, - pub device_id: Option, -} + impl sea_query::ValueType for $name { + fn try_from(v: Value) -> Result { + match v { + Value::TinyInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::Int(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigInt(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::TinyUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::SmallUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| 
sea_query::ValueTypeErr)?)) + } + Value::Unsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + Value::BigUnsigned(Some(int)) => { + Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) + } + _ => Err(sea_query::ValueTypeErr), + } + } -#[derive(Clone, Debug, PartialEq, Deserialize, Serialize, FromRow)] -pub struct WaitlistSummary { - #[sqlx(default)] - pub count: i64, - #[sqlx(default)] - pub linux_count: i64, - #[sqlx(default)] - pub mac_count: i64, - #[sqlx(default)] - pub windows_count: i64, - #[sqlx(default)] - pub unknown_count: i64, -} + fn type_name() -> String { + stringify!($name).into() + } -#[derive(FromRow, PartialEq, Debug, Serialize, Deserialize)] -pub struct Invite { - pub email_address: String, - pub email_confirmation_code: String, -} + fn array_type() -> sea_query::ArrayType { + sea_query::ArrayType::Int + } -#[derive(Debug, Serialize, Deserialize)] -pub struct NewUserParams { - pub github_login: String, - pub github_user_id: i32, - pub invite_count: i32, -} + fn column_type() -> sea_query::ColumnType { + sea_query::ColumnType::Integer(None) + } + } -#[derive(Debug)] -pub struct NewUserResult { - pub user_id: UserId, - pub metrics_id: String, - pub inviting_user_id: Option, - pub signup_device_id: Option, -} + impl sea_orm::TryFromU64 for $name { + fn try_from_u64(n: u64) -> Result { + Ok(Self(n.try_into().map_err(|_| { + DbErr::ConvertFromU64(concat!( + "error converting ", + stringify!($name), + " to u64" + )) + })?)) + } + } -fn random_invite_code() -> String { - nanoid::nanoid!(16) + impl sea_query::Nullable for $name { + fn null() -> Value { + Value::Int(None) + } + } + }; } -fn random_email_confirmation_code() -> String { - nanoid::nanoid!(64) -} +id_type!(AccessTokenId); +id_type!(ContactId); +id_type!(UserId); +id_type!(RoomId); +id_type!(RoomParticipantId); +id_type!(ProjectId); +id_type!(ProjectCollaboratorId); +id_type!(SignupId); +id_type!(WorktreeId); #[cfg(test)] pub use test::*; @@ 
-2882,35 +1352,40 @@ mod test { use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; + use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; use std::sync::Arc; - pub struct SqliteTestDb { - pub db: Option>>, - pub conn: sqlx::sqlite::SqliteConnection, - } - - pub struct PostgresTestDb { - pub db: Option>>, - pub url: String, + pub struct TestDb { + pub db: Option>, + pub connection: Option, } - impl SqliteTestDb { - pub fn new(background: Arc) -> Self { - let mut rng = StdRng::from_entropy(); - let url = format!("file:zed-test-{}?mode=memory", rng.gen::()); + impl TestDb { + pub fn sqlite(background: Arc) -> Self { + let url = format!("sqlite::memory:"); let runtime = tokio::runtime::Builder::new_current_thread() .enable_io() .enable_time() .build() .unwrap(); - let (mut db, conn) = runtime.block_on(async { - let db = Db::::new(&url, 5).await.unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations.sqlite"); - db.migrate(migrations_path.as_ref(), false).await.unwrap(); - let conn = db.pool.acquire().await.unwrap().detach(); - (db, conn) + let mut db = runtime.block_on(async { + let mut options = ConnectOptions::new(url); + options.max_connections(5); + let db = Database::new(options).await.unwrap(); + let sql = include_str!(concat!( + env!("CARGO_MANIFEST_DIR"), + "/migrations.sqlite/20221109000000_test_schema.sql" + )); + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + sql.into(), + )) + .await + .unwrap(); + db }); db.background = Some(background); @@ -2918,17 +1393,11 @@ mod test { Self { db: Some(Arc::new(db)), - conn, + connection: None, } } - pub fn db(&self) -> &Arc> { - self.db.as_ref().unwrap() - } - } - - impl PostgresTestDb { - pub fn new(background: Arc) -> Self { + pub fn postgres(background: Arc) -> Self { lazy_static! 
{ static ref LOCK: Mutex<()> = Mutex::new(()); } @@ -2949,7 +1418,11 @@ mod test { sqlx::Postgres::create_database(&url) .await .expect("failed to create test db"); - let db = Db::::new(&url, 5).await.unwrap(); + let mut options = ConnectOptions::new(url); + options + .max_connections(5) + .idle_timeout(Duration::from_secs(0)); + let db = Database::new(options).await.unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); db @@ -2960,19 +1433,40 @@ mod test { Self { db: Some(Arc::new(db)), - url, + connection: None, } } - pub fn db(&self) -> &Arc> { + pub fn db(&self) -> &Arc { self.db.as_ref().unwrap() } } - impl Drop for PostgresTestDb { + impl Drop for TestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); - db.teardown(&self.url); + if let DatabaseBackend::Postgres = db.pool.get_database_backend() { + db.runtime.as_ref().unwrap().block_on(async { + use util::ResultExt; + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE + pg_stat_activity.datname = current_database() AND + pid <> pg_backend_pid(); + "; + db.pool + .execute(sea_orm::Statement::from_string( + db.pool.get_database_backend(), + query.into(), + )) + .await + .log_err(); + sqlx::Postgres::drop_database(db.options.get_url()) + .await + .log_err(); + }) + } } } } diff --git a/crates/collab/src/db2/access_token.rs b/crates/collab/src/db/access_token.rs similarity index 100% rename from crates/collab/src/db2/access_token.rs rename to crates/collab/src/db/access_token.rs diff --git a/crates/collab/src/db2/contact.rs b/crates/collab/src/db/contact.rs similarity index 100% rename from crates/collab/src/db2/contact.rs rename to crates/collab/src/db/contact.rs diff --git a/crates/collab/src/db2/project.rs b/crates/collab/src/db/project.rs similarity index 100% rename from crates/collab/src/db2/project.rs rename to crates/collab/src/db/project.rs diff --git 
a/crates/collab/src/db2/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs similarity index 100% rename from crates/collab/src/db2/project_collaborator.rs rename to crates/collab/src/db/project_collaborator.rs diff --git a/crates/collab/src/db2/room.rs b/crates/collab/src/db/room.rs similarity index 100% rename from crates/collab/src/db2/room.rs rename to crates/collab/src/db/room.rs diff --git a/crates/collab/src/db2/room_participant.rs b/crates/collab/src/db/room_participant.rs similarity index 100% rename from crates/collab/src/db2/room_participant.rs rename to crates/collab/src/db/room_participant.rs diff --git a/crates/collab/src/db/schema.rs b/crates/collab/src/db/schema.rs deleted file mode 100644 index 40a3e334d19bf483302beab702ca4038500d0138..0000000000000000000000000000000000000000 --- a/crates/collab/src/db/schema.rs +++ /dev/null @@ -1,43 +0,0 @@ -pub mod project { - use sea_query::Iden; - - #[derive(Iden)] - pub enum Definition { - #[iden = "projects"] - Table, - Id, - RoomId, - HostUserId, - HostConnectionId, - } -} - -pub mod worktree { - use sea_query::Iden; - - #[derive(Iden)] - pub enum Definition { - #[iden = "worktrees"] - Table, - Id, - ProjectId, - AbsPath, - RootName, - Visible, - ScanId, - IsComplete, - } -} - -pub mod room_participant { - use sea_query::Iden; - - #[derive(Iden)] - pub enum Definition { - #[iden = "room_participants"] - Table, - RoomId, - UserId, - AnsweringConnectionId, - } -} diff --git a/crates/collab/src/db2/signup.rs b/crates/collab/src/db/signup.rs similarity index 95% rename from crates/collab/src/db2/signup.rs rename to crates/collab/src/db/signup.rs index 8fab8daa3621ebe93a08ed74fc02c47a7fdfae61..9857018a0c9bda338109428e14f7b2ee79b30e31 100644 --- a/crates/collab/src/db2/signup.rs +++ b/crates/collab/src/db/signup.rs @@ -27,7 +27,7 @@ pub enum Relation {} impl ActiveModelBehavior for ActiveModel {} -#[derive(Debug, PartialEq, Eq, FromQueryResult)] +#[derive(Debug, PartialEq, Eq, FromQueryResult, 
Serialize, Deserialize)] pub struct Invite { pub email_address: String, pub email_confirmation_code: String, diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 88488b10d26fda779611d698e608abcabc6ca688..b276bd5057b7282815a4c21eeea00fd691eecff5 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -6,14 +6,14 @@ macro_rules! test_both_dbs { ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { #[gpui::test] async fn $postgres_test_name() { - let test_db = PostgresTestDb::new(Deterministic::new(0).build_background()); + let test_db = TestDb::postgres(Deterministic::new(0).build_background()); let $db = test_db.db(); $body } #[gpui::test] async fn $sqlite_test_name() { - let test_db = SqliteTestDb::new(Deterministic::new(0).build_background()); + let test_db = TestDb::sqlite(Deterministic::new(0).build_background()); let $db = test_db.db(); $body } @@ -26,9 +26,10 @@ test_both_dbs!( db, { let mut user_ids = Vec::new(); + let mut user_metric_ids = Vec::new(); for i in 1..=4 { - user_ids.push( - db.create_user( + let user = db + .create_user( &format!("user{i}@example.com"), false, NewUserParams { @@ -38,9 +39,9 @@ test_both_dbs!( }, ) .await - .unwrap() - .user_id, - ); + .unwrap(); + user_ids.push(user.user_id); + user_metric_ids.push(user.metrics_id); } assert_eq!( @@ -52,6 +53,7 @@ test_both_dbs!( github_user_id: Some(1), email_address: Some("user1@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[0].parse().unwrap(), ..Default::default() }, User { @@ -60,6 +62,7 @@ test_both_dbs!( github_user_id: Some(2), email_address: Some("user2@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[1].parse().unwrap(), ..Default::default() }, User { @@ -68,6 +71,7 @@ test_both_dbs!( github_user_id: Some(3), email_address: Some("user3@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[2].parse().unwrap(), ..Default::default() }, User { @@ 
-76,6 +80,7 @@ test_both_dbs!( github_user_id: Some(4), email_address: Some("user4@example.com".to_string()), admin: false, + metrics_id: user_metric_ids[3].parse().unwrap(), ..Default::default() } ] @@ -399,14 +404,14 @@ test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { #[test] fn test_fuzzy_like_string() { - assert_eq!(DefaultDb::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(DefaultDb::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(DefaultDb::fuzzy_like_string(" z "), "%z%"); + assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); + assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); + assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); } #[gpui::test] async fn test_fuzzy_search_users() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); for (i, github_login) in [ "California", @@ -442,7 +447,7 @@ async fn test_fuzzy_search_users() { &["rhode-island", "colorado", "oregon"], ); - async fn fuzzy_search_user_names(db: &Db, query: &str) -> Vec { + async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { db.fuzzy_search_users(query, 10) .await .unwrap() @@ -454,7 +459,7 @@ async fn test_fuzzy_search_users() { #[gpui::test] async fn test_invite_codes() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); let NewUserResult { user_id: user1, .. 
} = db @@ -659,12 +664,12 @@ async fn test_invite_codes() { #[gpui::test] async fn test_signups() { - let test_db = PostgresTestDb::new(build_background_executor()); + let test_db = TestDb::postgres(build_background_executor()); let db = test_db.db(); // people sign up on the waitlist for i in 0..8 { - db.create_signup(Signup { + db.create_signup(NewSignup { email_address: format!("person-{i}@example.com"), platform_mac: true, platform_linux: i % 2 == 0, diff --git a/crates/collab/src/db2/user.rs b/crates/collab/src/db/user.rs similarity index 93% rename from crates/collab/src/db2/user.rs rename to crates/collab/src/db/user.rs index f6bac9dc77d8dd92ce9353019a610a76a83528ae..b6e096f667c2e858635cb8a1b53e2c505357bd23 100644 --- a/crates/collab/src/db2/user.rs +++ b/crates/collab/src/db/user.rs @@ -1,7 +1,8 @@ use super::UserId; use sea_orm::entity::prelude::*; +use serde::Serialize; -#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel)] +#[derive(Clone, Debug, Default, PartialEq, Eq, DeriveEntityModel, Serialize)] #[sea_orm(table_name = "users")] pub struct Model { #[sea_orm(primary_key)] @@ -12,6 +13,7 @@ pub struct Model { pub admin: bool, pub invite_code: Option, pub invite_count: i32, + pub inviter_id: Option, pub connected_once: bool, pub metrics_id: Uuid, } diff --git a/crates/collab/src/db2/worktree.rs b/crates/collab/src/db/worktree.rs similarity index 100% rename from crates/collab/src/db2/worktree.rs rename to crates/collab/src/db/worktree.rs diff --git a/crates/collab/src/db2.rs b/crates/collab/src/db2.rs deleted file mode 100644 index 3aa21c60593aaf4a60189076b7f298821a64e7da..0000000000000000000000000000000000000000 --- a/crates/collab/src/db2.rs +++ /dev/null @@ -1,1416 +0,0 @@ -mod access_token; -mod contact; -mod project; -mod project_collaborator; -mod room; -mod room_participant; -mod signup; -#[cfg(test)] -mod tests; -mod user; -mod worktree; - -use crate::{Error, Result}; -use anyhow::anyhow; -use collections::HashMap; -use 
dashmap::DashMap; -use futures::StreamExt; -use hyper::StatusCode; -use rpc::{proto, ConnectionId}; -use sea_orm::{ - entity::prelude::*, ConnectOptions, DatabaseConnection, DatabaseTransaction, DbErr, - TransactionTrait, -}; -use sea_orm::{ - ActiveValue, ConnectionTrait, DatabaseBackend, FromQueryResult, IntoActiveModel, JoinType, - QueryOrder, QuerySelect, Statement, -}; -use sea_query::{Alias, Expr, OnConflict, Query}; -use serde::{Deserialize, Serialize}; -use sqlx::migrate::{Migrate, Migration, MigrationSource}; -use sqlx::Connection; -use std::ops::{Deref, DerefMut}; -use std::path::Path; -use std::time::Duration; -use std::{future::Future, marker::PhantomData, rc::Rc, sync::Arc}; -use tokio::sync::{Mutex, OwnedMutexGuard}; - -pub use contact::Contact; -pub use signup::{Invite, NewSignup, WaitlistSummary}; -pub use user::Model as User; - -pub struct Database { - options: ConnectOptions, - pool: DatabaseConnection, - rooms: DashMap>>, - #[cfg(test)] - background: Option>, - #[cfg(test)] - runtime: Option, -} - -impl Database { - pub async fn new(options: ConnectOptions) -> Result { - Ok(Self { - options: options.clone(), - pool: sea_orm::Database::connect(options).await?, - rooms: DashMap::with_capacity(16384), - #[cfg(test)] - background: None, - #[cfg(test)] - runtime: None, - }) - } - - pub async fn migrate( - &self, - migrations_path: &Path, - ignore_checksum_mismatch: bool, - ) -> anyhow::Result> { - let migrations = MigrationSource::resolve(migrations_path) - .await - .map_err(|err| anyhow!("failed to load migrations: {err:?}"))?; - - let mut connection = sqlx::AnyConnection::connect(self.options.get_url()).await?; - - connection.ensure_migrations_table().await?; - let applied_migrations: HashMap<_, _> = connection - .list_applied_migrations() - .await? 
- .into_iter() - .map(|m| (m.version, m)) - .collect(); - - let mut new_migrations = Vec::new(); - for migration in migrations { - match applied_migrations.get(&migration.version) { - Some(applied_migration) => { - if migration.checksum != applied_migration.checksum && !ignore_checksum_mismatch - { - Err(anyhow!( - "checksum mismatch for applied migration {}", - migration.description - ))?; - } - } - None => { - let elapsed = connection.apply(&migration).await?; - new_migrations.push((migration, elapsed)); - } - } - } - - Ok(new_migrations) - } - - // users - - pub async fn create_user( - &self, - email_address: &str, - admin: bool, - params: NewUserParams, - ) -> Result { - self.transact(|tx| async { - let user = user::Entity::insert(user::ActiveModel { - email_address: ActiveValue::set(Some(email_address.into())), - github_login: ActiveValue::set(params.github_login.clone()), - github_user_id: ActiveValue::set(Some(params.github_user_id)), - admin: ActiveValue::set(admin), - metrics_id: ActiveValue::set(Uuid::new_v4()), - ..Default::default() - }) - .on_conflict( - OnConflict::column(user::Column::GithubLogin) - .update_column(user::Column::GithubLogin) - .to_owned(), - ) - .exec_with_returning(&tx) - .await?; - - tx.commit().await?; - - Ok(NewUserResult { - user_id: user.id, - metrics_id: user.metrics_id.to_string(), - signup_device_id: None, - inviting_user_id: None, - }) - }) - .await - } - - pub async fn get_user_by_id(&self, id: UserId) -> Result> { - self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) }) - .await - } - - pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.transact(|tx| async { - let tx = tx; - Ok(user::Entity::find() - .filter(user::Column::Id.is_in(ids.iter().copied())) - .all(&tx) - .await?) 
- }) - .await - } - - pub async fn get_user_by_github_account( - &self, - github_login: &str, - github_user_id: Option, - ) -> Result> { - self.transact(|tx| async { - let tx = tx; - if let Some(github_user_id) = github_user_id { - if let Some(user_by_github_user_id) = user::Entity::find() - .filter(user::Column::GithubUserId.eq(github_user_id)) - .one(&tx) - .await? - { - let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); - user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); - Ok(Some(user_by_github_user_id.update(&tx).await?)) - } else if let Some(user_by_github_login) = user::Entity::find() - .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) - .await? - { - let mut user_by_github_login = user_by_github_login.into_active_model(); - user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); - Ok(Some(user_by_github_login.update(&tx).await?)) - } else { - Ok(None) - } - } else { - Ok(user::Entity::find() - .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) - .await?) - } - }) - .await - } - - pub async fn get_user_metrics_id(&self, id: UserId) -> Result { - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - MetricsId, - } - - self.transact(|tx| async move { - let metrics_id: Uuid = user::Entity::find_by_id(id) - .select_only() - .column(user::Column::MetricsId) - .into_values::<_, QueryAs>() - .one(&tx) - .await? 
- .ok_or_else(|| anyhow!("could not find user"))?; - Ok(metrics_id.to_string()) - }) - .await - } - - // contacts - - pub async fn get_contacts(&self, user_id: UserId) -> Result> { - #[derive(Debug, FromQueryResult)] - struct ContactWithUserBusyStatuses { - user_id_a: UserId, - user_id_b: UserId, - a_to_b: bool, - accepted: bool, - should_notify: bool, - user_a_busy: bool, - user_b_busy: bool, - } - - self.transact(|tx| async move { - let user_a_participant = Alias::new("user_a_participant"); - let user_b_participant = Alias::new("user_b_participant"); - let mut db_contacts = contact::Entity::find() - .column_as( - Expr::tbl(user_a_participant.clone(), room_participant::Column::Id) - .is_not_null(), - "user_a_busy", - ) - .column_as( - Expr::tbl(user_b_participant.clone(), room_participant::Column::Id) - .is_not_null(), - "user_b_busy", - ) - .filter( - contact::Column::UserIdA - .eq(user_id) - .or(contact::Column::UserIdB.eq(user_id)), - ) - .join_as( - JoinType::LeftJoin, - contact::Relation::UserARoomParticipant.def(), - user_a_participant, - ) - .join_as( - JoinType::LeftJoin, - contact::Relation::UserBRoomParticipant.def(), - user_b_participant, - ) - .into_model::() - .stream(&tx) - .await?; - - let mut contacts = Vec::new(); - while let Some(db_contact) = db_contacts.next().await { - let db_contact = db_contact?; - if db_contact.user_id_a == user_id { - if db_contact.accepted { - contacts.push(Contact::Accepted { - user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify && db_contact.a_to_b, - busy: db_contact.user_b_busy, - }); - } else if db_contact.a_to_b { - contacts.push(Contact::Outgoing { - user_id: db_contact.user_id_b, - }) - } else { - contacts.push(Contact::Incoming { - user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify, - }); - } - } else if db_contact.accepted { - contacts.push(Contact::Accepted { - user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify && !db_contact.a_to_b, - busy: 
db_contact.user_a_busy, - }); - } else if db_contact.a_to_b { - contacts.push(Contact::Incoming { - user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify, - }); - } else { - contacts.push(Contact::Outgoing { - user_id: db_contact.user_id_a, - }); - } - } - - contacts.sort_unstable_by_key(|contact| contact.user_id()); - - Ok(contacts) - }) - .await - } - - pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - self.transact(|tx| async move { - let (id_a, id_b) = if user_id_1 < user_id_2 { - (user_id_1, user_id_2) - } else { - (user_id_2, user_id_1) - }; - - Ok(contact::Entity::find() - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and(contact::Column::Accepted.eq(true)), - ) - .one(&tx) - .await? - .is_some()) - }) - .await - } - - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b, a_to_b) = if sender_id < receiver_id { - (sender_id, receiver_id, true) - } else { - (receiver_id, sender_id, false) - }; - - let rows_affected = contact::Entity::insert(contact::ActiveModel { - user_id_a: ActiveValue::set(id_a), - user_id_b: ActiveValue::set(id_b), - a_to_b: ActiveValue::set(a_to_b), - accepted: ActiveValue::set(false), - should_notify: ActiveValue::set(true), - ..Default::default() - }) - .on_conflict( - OnConflict::columns([contact::Column::UserIdA, contact::Column::UserIdB]) - .values([ - (contact::Column::Accepted, true.into()), - (contact::Column::ShouldNotify, false.into()), - ]) - .action_and_where( - contact::Column::Accepted.eq(false).and( - contact::Column::AToB - .eq(a_to_b) - .and(contact::Column::UserIdA.eq(id_b)) - .or(contact::Column::AToB - .ne(a_to_b) - .and(contact::Column::UserIdA.eq(id_a))), - ), - ) - .to_owned(), - ) - .exec_without_returning(&tx) - .await?; - - if rows_affected == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("contact already 
requested"))? - } - }) - .await - } - - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b) = if responder_id < requester_id { - (responder_id, requester_id) - } else { - (requester_id, responder_id) - }; - - let result = contact::Entity::delete_many() - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)), - ) - .exec(&tx) - .await?; - - if result.rows_affected == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact"))? - } - }) - .await - } - - pub async fn dismiss_contact_notification( - &self, - user_id: UserId, - contact_user_id: UserId, - ) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b, a_to_b) = if user_id < contact_user_id { - (user_id, contact_user_id, true) - } else { - (contact_user_id, user_id, false) - }; - - let result = contact::Entity::update_many() - .set(contact::ActiveModel { - should_notify: ActiveValue::set(false), - ..Default::default() - }) - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and( - contact::Column::AToB - .eq(a_to_b) - .and(contact::Column::Accepted.eq(true)) - .or(contact::Column::AToB - .ne(a_to_b) - .and(contact::Column::Accepted.eq(false))), - ), - ) - .exec(&tx) - .await?; - if result.rows_affected == 0 { - Err(anyhow!("no such contact request"))? 
- } else { - tx.commit().await?; - Ok(()) - } - }) - .await - } - - pub async fn respond_to_contact_request( - &self, - responder_id: UserId, - requester_id: UserId, - accept: bool, - ) -> Result<()> { - self.transact(|tx| async move { - let (id_a, id_b, a_to_b) = if responder_id < requester_id { - (responder_id, requester_id, false) - } else { - (requester_id, responder_id, true) - }; - let rows_affected = if accept { - let result = contact::Entity::update_many() - .set(contact::ActiveModel { - accepted: ActiveValue::set(true), - should_notify: ActiveValue::set(true), - ..Default::default() - }) - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and(contact::Column::AToB.eq(a_to_b)), - ) - .exec(&tx) - .await?; - result.rows_affected - } else { - let result = contact::Entity::delete_many() - .filter( - contact::Column::UserIdA - .eq(id_a) - .and(contact::Column::UserIdB.eq(id_b)) - .and(contact::Column::AToB.eq(a_to_b)) - .and(contact::Column::Accepted.eq(false)), - ) - .exec(&tx) - .await?; - - result.rows_affected - }; - - if rows_affected == 1 { - tx.commit().await?; - Ok(()) - } else { - Err(anyhow!("no such contact request"))? - } - }) - .await - } - - pub fn fuzzy_like_string(string: &str) -> String { - let mut result = String::with_capacity(string.len() * 2 + 1); - for c in string.chars() { - if c.is_alphanumeric() { - result.push('%'); - result.push(c); - } - } - result.push('%'); - result - } - - pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - self.transact(|tx| async { - let tx = tx; - let like_string = Self::fuzzy_like_string(name_query); - let query = " - SELECT users.* - FROM users - WHERE github_login ILIKE $1 - ORDER BY github_login <-> $2 - LIMIT $3 - "; - - Ok(user::Entity::find() - .from_raw_sql(Statement::from_sql_and_values( - self.pool.get_database_backend(), - query.into(), - vec![like_string.into(), name_query.into(), limit.into()], - )) - .all(&tx) - .await?) 
- }) - .await - } - - // signups - - pub async fn create_signup(&self, signup: NewSignup) -> Result<()> { - self.transact(|tx| async { - signup::ActiveModel { - email_address: ActiveValue::set(signup.email_address.clone()), - email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), - email_confirmation_sent: ActiveValue::set(false), - platform_mac: ActiveValue::set(signup.platform_mac), - platform_windows: ActiveValue::set(signup.platform_windows), - platform_linux: ActiveValue::set(signup.platform_linux), - platform_unknown: ActiveValue::set(false), - editor_features: ActiveValue::set(Some(signup.editor_features.clone())), - programming_languages: ActiveValue::set(Some(signup.programming_languages.clone())), - device_id: ActiveValue::set(signup.device_id.clone()), - ..Default::default() - } - .insert(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_waitlist_summary(&self) -> Result { - self.transact(|tx| async move { - let query = " - SELECT - COUNT(*) as count, - COALESCE(SUM(CASE WHEN platform_linux THEN 1 ELSE 0 END), 0) as linux_count, - COALESCE(SUM(CASE WHEN platform_mac THEN 1 ELSE 0 END), 0) as mac_count, - COALESCE(SUM(CASE WHEN platform_windows THEN 1 ELSE 0 END), 0) as windows_count, - COALESCE(SUM(CASE WHEN platform_unknown THEN 1 ELSE 0 END), 0) as unknown_count - FROM ( - SELECT * - FROM signups - WHERE - NOT email_confirmation_sent - ) AS unsent - "; - Ok( - WaitlistSummary::find_by_statement(Statement::from_sql_and_values( - self.pool.get_database_backend(), - query.into(), - vec![], - )) - .one(&tx) - .await? 
- .ok_or_else(|| anyhow!("invalid result"))?, - ) - }) - .await - } - - pub async fn record_sent_invites(&self, invites: &[Invite]) -> Result<()> { - let emails = invites - .iter() - .map(|s| s.email_address.as_str()) - .collect::>(); - self.transact(|tx| async { - signup::Entity::update_many() - .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) - .col_expr(signup::Column::EmailConfirmationSent, true.into()) - .exec(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_unsent_invites(&self, count: usize) -> Result> { - self.transact(|tx| async move { - Ok(signup::Entity::find() - .select_only() - .column(signup::Column::EmailAddress) - .column(signup::Column::EmailConfirmationCode) - .filter( - signup::Column::EmailConfirmationSent.eq(false).and( - signup::Column::PlatformMac - .eq(true) - .or(signup::Column::PlatformUnknown.eq(true)), - ), - ) - .limit(count as u64) - .into_model() - .all(&tx) - .await?) - }) - .await - } - - // invite codes - - pub async fn create_invite_from_code( - &self, - code: &str, - email_address: &str, - device_id: Option<&str>, - ) -> Result { - self.transact(|tx| async move { - let existing_user = user::Entity::find() - .filter(user::Column::EmailAddress.eq(email_address)) - .one(&tx) - .await?; - - if existing_user.is_some() { - Err(anyhow!("email address is already in use"))?; - } - - let inviter = match user::Entity::find() - .filter(user::Column::InviteCode.eq(code)) - .one(&tx) - .await? - { - Some(inviter) => inviter, - None => { - return Err(Error::Http( - StatusCode::NOT_FOUND, - "invite code not found".to_string(), - ))? 
- } - }; - - if inviter.invite_count == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - let signup = signup::Entity::insert(signup::ActiveModel { - email_address: ActiveValue::set(email_address.into()), - email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), - email_confirmation_sent: ActiveValue::set(false), - inviting_user_id: ActiveValue::set(Some(inviter.id)), - platform_linux: ActiveValue::set(false), - platform_mac: ActiveValue::set(false), - platform_windows: ActiveValue::set(false), - platform_unknown: ActiveValue::set(true), - device_id: ActiveValue::set(device_id.map(|device_id| device_id.into())), - ..Default::default() - }) - .on_conflict( - OnConflict::column(signup::Column::EmailAddress) - .update_column(signup::Column::InvitingUserId) - .to_owned(), - ) - .exec_with_returning(&tx) - .await?; - tx.commit().await?; - - Ok(Invite { - email_address: signup.email_address, - email_confirmation_code: signup.email_confirmation_code, - }) - }) - .await - } - - pub async fn create_user_from_invite( - &self, - invite: &Invite, - user: NewUserParams, - ) -> Result> { - self.transact(|tx| async { - let tx = tx; - let signup = signup::Entity::find() - .filter( - signup::Column::EmailAddress - .eq(invite.email_address.as_str()) - .and( - signup::Column::EmailConfirmationCode - .eq(invite.email_confirmation_code.as_str()), - ), - ) - .one(&tx) - .await? 
- .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; - - if signup.user_id.is_some() { - return Ok(None); - } - - let user = user::Entity::insert(user::ActiveModel { - email_address: ActiveValue::set(Some(invite.email_address.clone())), - github_login: ActiveValue::set(user.github_login.clone()), - github_user_id: ActiveValue::set(Some(user.github_user_id)), - admin: ActiveValue::set(false), - invite_count: ActiveValue::set(user.invite_count), - invite_code: ActiveValue::set(Some(random_invite_code())), - metrics_id: ActiveValue::set(Uuid::new_v4()), - ..Default::default() - }) - .on_conflict( - OnConflict::column(user::Column::GithubLogin) - .update_columns([ - user::Column::EmailAddress, - user::Column::GithubUserId, - user::Column::Admin, - ]) - .to_owned(), - ) - .exec_with_returning(&tx) - .await?; - - let mut signup = signup.into_active_model(); - signup.user_id = ActiveValue::set(Some(user.id)); - let signup = signup.update(&tx).await?; - - if let Some(inviting_user_id) = signup.inviting_user_id { - let result = user::Entity::update_many() - .filter( - user::Column::Id - .eq(inviting_user_id) - .and(user::Column::InviteCount.gt(0)), - ) - .col_expr( - user::Column::InviteCount, - Expr::col(user::Column::InviteCount).sub(1), - ) - .exec(&tx) - .await?; - - if result.rows_affected == 0 { - Err(Error::Http( - StatusCode::UNAUTHORIZED, - "no invites remaining".to_string(), - ))?; - } - - contact::Entity::insert(contact::ActiveModel { - user_id_a: ActiveValue::set(inviting_user_id), - user_id_b: ActiveValue::set(user.id), - a_to_b: ActiveValue::set(true), - should_notify: ActiveValue::set(true), - accepted: ActiveValue::set(true), - ..Default::default() - }) - .on_conflict(OnConflict::new().do_nothing().to_owned()) - .exec_without_returning(&tx) - .await?; - } - - tx.commit().await?; - Ok(Some(NewUserResult { - user_id: user.id, - metrics_id: user.metrics_id.to_string(), - inviting_user_id: signup.inviting_user_id, - 
signup_device_id: signup.device_id, - })) - }) - .await - } - - pub async fn set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { - self.transact(|tx| async move { - if count > 0 { - user::Entity::update_many() - .filter( - user::Column::Id - .eq(id) - .and(user::Column::InviteCode.is_null()), - ) - .col_expr(user::Column::InviteCode, random_invite_code().into()) - .exec(&tx) - .await?; - } - - user::Entity::update_many() - .filter(user::Column::Id.eq(id)) - .col_expr(user::Column::InviteCount, count.into()) - .exec(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - self.transact(|tx| async move { - match user::Entity::find_by_id(id).one(&tx).await? { - Some(user) if user.invite_code.is_some() => { - Ok(Some((user.invite_code.unwrap(), user.invite_count as u32))) - } - _ => Ok(None), - } - }) - .await - } - - pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - self.transact(|tx| async move { - user::Entity::find() - .filter(user::Column::InviteCode.eq(code)) - .one(&tx) - .await? - .ok_or_else(|| { - Error::Http( - StatusCode::NOT_FOUND, - "that invite code does not exist".to_string(), - ) - }) - }) - .await - } - - // projects - - pub async fn share_project( - &self, - room_id: RoomId, - connection_id: ConnectionId, - worktrees: &[proto::WorktreeMetadata], - ) -> Result> { - self.transact(|tx| async move { - let participant = room_participant::Entity::find() - .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) - .await? 
- .ok_or_else(|| anyhow!("could not find participant"))?; - if participant.room_id != room_id { - return Err(anyhow!("shared project on unexpected room"))?; - } - - let project = project::ActiveModel { - room_id: ActiveValue::set(participant.room_id), - host_user_id: ActiveValue::set(participant.user_id), - host_connection_id: ActiveValue::set(connection_id.0 as i32), - ..Default::default() - } - .insert(&tx) - .await?; - - worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i32), - project_id: ActiveValue::set(project.id), - abs_path: ActiveValue::set(worktree.abs_path.clone()), - root_name: ActiveValue::set(worktree.root_name.clone()), - visible: ActiveValue::set(worktree.visible), - scan_id: ActiveValue::set(0), - is_complete: ActiveValue::set(false), - })) - .exec(&tx) - .await?; - - project_collaborator::ActiveModel { - project_id: ActiveValue::set(project.id), - connection_id: ActiveValue::set(connection_id.0 as i32), - user_id: ActiveValue::set(participant.user_id), - replica_id: ActiveValue::set(0), - is_host: ActiveValue::set(true), - ..Default::default() - } - .insert(&tx) - .await?; - - let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (project.id, room)) - .await - }) - .await - } - - async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { - let db_room = room::Entity::find_by_id(room_id) - .one(tx) - .await? 
- .ok_or_else(|| anyhow!("could not find room"))?; - - let mut db_participants = db_room - .find_related(room_participant::Entity) - .stream(tx) - .await?; - let mut participants = HashMap::default(); - let mut pending_participants = Vec::new(); - while let Some(db_participant) = db_participants.next().await { - let db_participant = db_participant?; - if let Some(answering_connection_id) = db_participant.answering_connection_id { - let location = match ( - db_participant.location_kind, - db_participant.location_project_id, - ) { - (Some(0), Some(project_id)) => { - Some(proto::participant_location::Variant::SharedProject( - proto::participant_location::SharedProject { - id: project_id.to_proto(), - }, - )) - } - (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject( - Default::default(), - )), - _ => Some(proto::participant_location::Variant::External( - Default::default(), - )), - }; - participants.insert( - answering_connection_id, - proto::Participant { - user_id: db_participant.user_id.to_proto(), - peer_id: answering_connection_id as u32, - projects: Default::default(), - location: Some(proto::ParticipantLocation { variant: location }), - }, - ); - } else { - pending_participants.push(proto::PendingParticipant { - user_id: db_participant.user_id.to_proto(), - calling_user_id: db_participant.calling_user_id.to_proto(), - initial_project_id: db_participant.initial_project_id.map(|id| id.to_proto()), - }); - } - } - - let mut db_projects = db_room - .find_related(project::Entity) - .find_with_related(worktree::Entity) - .stream(tx) - .await?; - - while let Some(row) = db_projects.next().await { - let (db_project, db_worktree) = row?; - if let Some(participant) = participants.get_mut(&db_project.host_connection_id) { - let project = if let Some(project) = participant - .projects - .iter_mut() - .find(|project| project.id == db_project.id.to_proto()) - { - project - } else { - participant.projects.push(proto::ParticipantProject { - id: 
db_project.id.to_proto(), - worktree_root_names: Default::default(), - }); - participant.projects.last_mut().unwrap() - }; - - if let Some(db_worktree) = db_worktree { - project.worktree_root_names.push(db_worktree.root_name); - } - } - } - - Ok(proto::Room { - id: db_room.id.to_proto(), - live_kit_room: db_room.live_kit_room, - participants: participants.into_values().collect(), - pending_participants, - }) - } - - async fn commit_room_transaction( - &self, - room_id: RoomId, - tx: DatabaseTransaction, - data: T, - ) -> Result> { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }) - } - - pub async fn create_access_token_hash( - &self, - user_id: UserId, - access_token_hash: &str, - max_access_token_count: usize, - ) -> Result<()> { - self.transact(|tx| async { - let tx = tx; - - access_token::ActiveModel { - user_id: ActiveValue::set(user_id), - hash: ActiveValue::set(access_token_hash.into()), - ..Default::default() - } - .insert(&tx) - .await?; - - access_token::Entity::delete_many() - .filter( - access_token::Column::Id.in_subquery( - Query::select() - .column(access_token::Column::Id) - .from(access_token::Entity) - .and_where(access_token::Column::UserId.eq(user_id)) - .order_by(access_token::Column::Id, sea_orm::Order::Desc) - .limit(10000) - .offset(max_access_token_count as u64) - .to_owned(), - ), - ) - .exec(&tx) - .await?; - tx.commit().await?; - Ok(()) - }) - .await - } - - pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { - #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] - enum QueryAs { - Hash, - } - - self.transact(|tx| async move { - Ok(access_token::Entity::find() - .select_only() - .column(access_token::Column::Hash) - .filter(access_token::Column::UserId.eq(user_id)) - .order_by_desc(access_token::Column::Id) - .into_values::<_, QueryAs>() - .all(&tx) - .await?) 
- }) - .await - } - - async fn transact(&self, f: F) -> Result - where - F: Send + Fn(DatabaseTransaction) -> Fut, - Fut: Send + Future>, - { - let body = async { - loop { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) - .await?; - } - - match f(tx).await { - Ok(result) => return Ok(result), - Err(error) => match error { - Error::Database2( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { - // Retry (don't break the loop) - } - error @ _ => return Err(error), - }, - } - } - }; - - #[cfg(test)] - { - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - self.runtime.as_ref().unwrap().block_on(body) - } - - #[cfg(not(test))] - { - body.await - } - } -} - -pub struct RoomGuard { - data: T, - _guard: OwnedMutexGuard<()>, - _not_send: PhantomData>, -} - -impl Deref for RoomGuard { - type Target = T; - - fn deref(&self) -> &T { - &self.data - } -} - -impl DerefMut for RoomGuard { - fn deref_mut(&mut self) -> &mut T { - &mut self.data - } -} - -#[derive(Debug, Serialize, Deserialize)] -pub struct NewUserParams { - pub github_login: String, - pub github_user_id: i32, - pub invite_count: i32, -} - -#[derive(Debug)] -pub struct NewUserResult { - pub user_id: UserId, - pub metrics_id: String, - pub inviting_user_id: Option, - pub signup_device_id: Option, -} - -fn random_invite_code() -> String { - nanoid::nanoid!(16) -} - -fn random_email_confirmation_code() -> String { - nanoid::nanoid!(64) -} - -macro_rules! 
id_type { - ($name:ident) => { - #[derive( - Clone, - Copy, - Debug, - Default, - PartialEq, - Eq, - PartialOrd, - Ord, - Hash, - sqlx::Type, - Serialize, - Deserialize, - )] - #[sqlx(transparent)] - #[serde(transparent)] - pub struct $name(pub i32); - - impl $name { - #[allow(unused)] - pub const MAX: Self = Self(i32::MAX); - - #[allow(unused)] - pub fn from_proto(value: u64) -> Self { - Self(value as i32) - } - - #[allow(unused)] - pub fn to_proto(self) -> u64 { - self.0 as u64 - } - } - - impl std::fmt::Display for $name { - fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { - self.0.fmt(f) - } - } - - impl From<$name> for sea_query::Value { - fn from(value: $name) -> Self { - sea_query::Value::Int(Some(value.0)) - } - } - - impl sea_orm::TryGetable for $name { - fn try_get( - res: &sea_orm::QueryResult, - pre: &str, - col: &str, - ) -> Result { - Ok(Self(i32::try_get(res, pre, col)?)) - } - } - - impl sea_query::ValueType for $name { - fn try_from(v: Value) -> Result { - match v { - Value::TinyInt(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::SmallInt(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::Int(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::BigInt(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::TinyUnsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::SmallUnsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::Unsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - Value::BigUnsigned(Some(int)) => { - Ok(Self(int.try_into().map_err(|_| sea_query::ValueTypeErr)?)) - } - _ => Err(sea_query::ValueTypeErr), - } - } - - fn type_name() -> String { - stringify!($name).into() - } - - fn array_type() -> sea_query::ArrayType { - 
sea_query::ArrayType::Int - } - - fn column_type() -> sea_query::ColumnType { - sea_query::ColumnType::Integer(None) - } - } - - impl sea_orm::TryFromU64 for $name { - fn try_from_u64(n: u64) -> Result { - Ok(Self(n.try_into().map_err(|_| { - DbErr::ConvertFromU64(concat!( - "error converting ", - stringify!($name), - " to u64" - )) - })?)) - } - } - - impl sea_query::Nullable for $name { - fn null() -> Value { - Value::Int(None) - } - } - }; -} - -id_type!(AccessTokenId); -id_type!(ContactId); -id_type!(UserId); -id_type!(RoomId); -id_type!(RoomParticipantId); -id_type!(ProjectId); -id_type!(ProjectCollaboratorId); -id_type!(SignupId); -id_type!(WorktreeId); - -#[cfg(test)] -pub use test::*; - -#[cfg(test)] -mod test { - use super::*; - use gpui::executor::Background; - use lazy_static::lazy_static; - use parking_lot::Mutex; - use rand::prelude::*; - use sea_orm::ConnectionTrait; - use sqlx::migrate::MigrateDatabase; - use std::sync::Arc; - - pub struct TestDb { - pub db: Option>, - pub connection: Option, - } - - impl TestDb { - pub fn sqlite(background: Arc) -> Self { - let url = format!("sqlite::memory:"); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - let mut options = ConnectOptions::new(url); - options.max_connections(5); - let db = Database::new(options).await.unwrap(); - let sql = include_str!(concat!( - env!("CARGO_MANIFEST_DIR"), - "/migrations.sqlite/20221109000000_test_schema.sql" - )); - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - sql.into(), - )) - .await - .unwrap(); - db - }); - - db.background = Some(background); - db.runtime = Some(runtime); - - Self { - db: Some(Arc::new(db)), - connection: None, - } - } - - pub fn postgres(background: Arc) -> Self { - lazy_static! 
{ - static ref LOCK: Mutex<()> = Mutex::new(()); - } - - let _guard = LOCK.lock(); - let mut rng = StdRng::from_entropy(); - let url = format!( - "postgres://postgres@localhost/zed-test-{}", - rng.gen::() - ); - let runtime = tokio::runtime::Builder::new_current_thread() - .enable_io() - .enable_time() - .build() - .unwrap(); - - let mut db = runtime.block_on(async { - sqlx::Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let mut options = ConnectOptions::new(url); - options - .max_connections(5) - .idle_timeout(Duration::from_secs(0)); - let db = Database::new(options).await.unwrap(); - let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); - db.migrate(Path::new(migrations_path), false).await.unwrap(); - db - }); - - db.background = Some(background); - db.runtime = Some(runtime); - - Self { - db: Some(Arc::new(db)), - connection: None, - } - } - - pub fn db(&self) -> &Arc { - self.db.as_ref().unwrap() - } - } - - impl Drop for TestDb { - fn drop(&mut self) { - let db = self.db.take().unwrap(); - if let DatabaseBackend::Postgres = db.pool.get_database_backend() { - db.runtime.as_ref().unwrap().block_on(async { - use util::ResultExt; - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE - pg_stat_activity.datname = current_database() AND - pid <> pg_backend_pid(); - "; - db.pool - .execute(sea_orm::Statement::from_string( - db.pool.get_database_backend(), - query.into(), - )) - .await - .log_err(); - sqlx::Postgres::drop_database(db.options.get_url()) - .await - .log_err(); - }) - } - } - } -} diff --git a/crates/collab/src/db2/tests.rs b/crates/collab/src/db2/tests.rs deleted file mode 100644 index b276bd5057b7282815a4c21eeea00fd691eecff5..0000000000000000000000000000000000000000 --- a/crates/collab/src/db2/tests.rs +++ /dev/null @@ -1,813 +0,0 @@ -use super::*; -use gpui::executor::{Background, Deterministic}; -use std::sync::Arc; - -macro_rules! 
test_both_dbs { - ($postgres_test_name:ident, $sqlite_test_name:ident, $db:ident, $body:block) => { - #[gpui::test] - async fn $postgres_test_name() { - let test_db = TestDb::postgres(Deterministic::new(0).build_background()); - let $db = test_db.db(); - $body - } - - #[gpui::test] - async fn $sqlite_test_name() { - let test_db = TestDb::sqlite(Deterministic::new(0).build_background()); - let $db = test_db.db(); - $body - } - }; -} - -test_both_dbs!( - test_get_users_by_ids_postgres, - test_get_users_by_ids_sqlite, - db, - { - let mut user_ids = Vec::new(); - let mut user_metric_ids = Vec::new(); - for i in 1..=4 { - let user = db - .create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap(); - user_ids.push(user.user_id); - user_metric_ids.push(user.metrics_id); - } - - assert_eq!( - db.get_users_by_ids(user_ids.clone()).await.unwrap(), - vec![ - User { - id: user_ids[0], - github_login: "user1".to_string(), - github_user_id: Some(1), - email_address: Some("user1@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[0].parse().unwrap(), - ..Default::default() - }, - User { - id: user_ids[1], - github_login: "user2".to_string(), - github_user_id: Some(2), - email_address: Some("user2@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[1].parse().unwrap(), - ..Default::default() - }, - User { - id: user_ids[2], - github_login: "user3".to_string(), - github_user_id: Some(3), - email_address: Some("user3@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[2].parse().unwrap(), - ..Default::default() - }, - User { - id: user_ids[3], - github_login: "user4".to_string(), - github_user_id: Some(4), - email_address: Some("user4@example.com".to_string()), - admin: false, - metrics_id: user_metric_ids[3].parse().unwrap(), - ..Default::default() - } - ] - ); - } -); - -test_both_dbs!( - 
test_get_user_by_github_account_postgres, - test_get_user_by_github_account_sqlite, - db, - { - let user_id1 = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "login1".into(), - github_user_id: 101, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - let user_id2 = db - .create_user( - "user2@example.com", - false, - NewUserParams { - github_login: "login2".into(), - github_user_id: 102, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - let user = db - .get_user_by_github_account("login1", None) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id1); - assert_eq!(&user.github_login, "login1"); - assert_eq!(user.github_user_id, Some(101)); - - assert!(db - .get_user_by_github_account("non-existent-login", None) - .await - .unwrap() - .is_none()); - - let user = db - .get_user_by_github_account("the-new-login2", Some(102)) - .await - .unwrap() - .unwrap(); - assert_eq!(user.id, user_id2); - assert_eq!(&user.github_login, "the-new-login2"); - assert_eq!(user.github_user_id, Some(102)); - } -); - -test_both_dbs!( - test_create_access_tokens_postgres, - test_create_access_tokens_sqlite, - db, - { - let user = db - .create_user( - "u1@example.com", - false, - NewUserParams { - github_login: "u1".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; - - db.create_access_token_hash(user, "h1", 3).await.unwrap(); - db.create_access_token_hash(user, "h2", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h2".to_string(), "h1".to_string()] - ); - - db.create_access_token_hash(user, "h3", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h3".to_string(), "h2".to_string(), "h1".to_string(),] - ); - - db.create_access_token_hash(user, "h4", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h4".to_string(), "h3".to_string(), "h2".to_string(),] - ); - - 
db.create_access_token_hash(user, "h5", 3).await.unwrap(); - assert_eq!( - db.get_access_token_hashes(user).await.unwrap(), - &["h5".to_string(), "h4".to_string(), "h3".to_string()] - ); - } -); - -test_both_dbs!(test_add_contacts_postgres, test_add_contacts_sqlite, db, { - let mut user_ids = Vec::new(); - for i in 0..3 { - user_ids.push( - db.create_user( - &format!("user{i}@example.com"), - false, - NewUserParams { - github_login: format!("user{i}"), - github_user_id: i, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id, - ); - } - - let user_1 = user_ids[0]; - let user_2 = user_ids[1]; - let user_3 = user_ids[2]; - - // User starts with no contacts - assert_eq!(db.get_contacts(user_1).await.unwrap(), &[]); - - // User requests a contact. Both users see the pending request. - db.send_contact_request(user_1, user_2).await.unwrap(); - assert!(!db.has_contact(user_1, user_2).await.unwrap()); - assert!(!db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Outgoing { user_id: user_2 }], - ); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] - ); - - // User 2 dismisses the contact request notification without accepting or rejecting. - // We shouldn't notify them again. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap_err(); - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] - ); - - // User can't accept their own contact request - db.respond_to_contact_request(user_1, user_2, true) - .await - .unwrap_err(); - - // User accepts a contact request. Both users see the contact. 
- db.respond_to_contact_request(user_2, user_1, true) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - busy: false, - }], - ); - assert!(db.has_contact(user_1, user_2).await.unwrap()); - assert!(db.has_contact(user_2, user_1).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }] - ); - - // Users cannot re-request existing contacts. - db.send_contact_request(user_1, user_2).await.unwrap_err(); - db.send_contact_request(user_2, user_1).await.unwrap_err(); - - // Users can't dismiss notifications of them accepting other users' requests. - db.dismiss_contact_notification(user_2, user_1) - .await - .unwrap_err(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: true, - busy: false, - }] - ); - - // Users can dismiss notifications of other users accepting their requests. - db.dismiss_contact_notification(user_1, user_2) - .await - .unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[Contact::Accepted { - user_id: user_2, - should_notify: false, - busy: false, - }] - ); - - // Users send each other concurrent contact requests and - // see that they are immediately accepted. - db.send_contact_request(user_1, user_3).await.unwrap(); - db.send_contact_request(user_3, user_1).await.unwrap(); - assert_eq!( - db.get_contacts(user_1).await.unwrap(), - &[ - Contact::Accepted { - user_id: user_2, - should_notify: false, - busy: false, - }, - Contact::Accepted { - user_id: user_3, - should_notify: false, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }], - ); - - // User declines a contact request. Both users see that it is gone. 
- db.send_contact_request(user_2, user_3).await.unwrap(); - db.respond_to_contact_request(user_3, user_2, false) - .await - .unwrap(); - assert!(!db.has_contact(user_2, user_3).await.unwrap()); - assert!(!db.has_contact(user_3, user_2).await.unwrap()); - assert_eq!( - db.get_contacts(user_2).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_contacts(user_3).await.unwrap(), - &[Contact::Accepted { - user_id: user_1, - should_notify: false, - busy: false, - }], - ); -}); - -test_both_dbs!(test_metrics_id_postgres, test_metrics_id_sqlite, db, { - let NewUserResult { - user_id: user1, - metrics_id: metrics_id1, - .. - } = db - .create_user( - "person1@example.com", - false, - NewUserParams { - github_login: "person1".into(), - github_user_id: 101, - invite_count: 5, - }, - ) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - metrics_id: metrics_id2, - .. - } = db - .create_user( - "person2@example.com", - false, - NewUserParams { - github_login: "person2".into(), - github_user_id: 102, - invite_count: 5, - }, - ) - .await - .unwrap(); - - assert_eq!(db.get_user_metrics_id(user1).await.unwrap(), metrics_id1); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id2); - assert_eq!(metrics_id1.len(), 36); - assert_eq!(metrics_id2.len(), 36); - assert_ne!(metrics_id1, metrics_id2); -}); - -#[test] -fn test_fuzzy_like_string() { - assert_eq!(Database::fuzzy_like_string("abcd"), "%a%b%c%d%"); - assert_eq!(Database::fuzzy_like_string("x y"), "%x%y%"); - assert_eq!(Database::fuzzy_like_string(" z "), "%z%"); -} - -#[gpui::test] -async fn test_fuzzy_search_users() { - let test_db = TestDb::postgres(build_background_executor()); - let db = test_db.db(); - for (i, github_login) in [ - "California", - "colorado", - "oregon", - "washington", - "florida", - "delaware", - "rhode-island", - ] - .into_iter() - .enumerate() - { - db.create_user( - 
&format!("{github_login}@example.com"), - false, - NewUserParams { - github_login: github_login.into(), - github_user_id: i as i32, - invite_count: 0, - }, - ) - .await - .unwrap(); - } - - assert_eq!( - fuzzy_search_user_names(db, "clr").await, - &["colorado", "California"] - ); - assert_eq!( - fuzzy_search_user_names(db, "ro").await, - &["rhode-island", "colorado", "oregon"], - ); - - async fn fuzzy_search_user_names(db: &Database, query: &str) -> Vec { - db.fuzzy_search_users(query, 10) - .await - .unwrap() - .into_iter() - .map(|user| user.github_login) - .collect::>() - } -} - -#[gpui::test] -async fn test_invite_codes() { - let test_db = TestDb::postgres(build_background_executor()); - let db = test_db.db(); - - let NewUserResult { user_id: user1, .. } = db - .create_user( - "user1@example.com", - false, - NewUserParams { - github_login: "user1".into(), - github_user_id: 0, - invite_count: 0, - }, - ) - .await - .unwrap(); - - // Initially, user 1 has no invite code - assert_eq!(db.get_invite_code_for_user(user1).await.unwrap(), None); - - // Setting invite count to 0 when no code is assigned does not assign a new code - db.set_invite_count_for_user(user1, 0).await.unwrap(); - assert!(db.get_invite_code_for_user(user1).await.unwrap().is_none()); - - // User 1 creates an invite code that can be used twice. - db.set_invite_count_for_user(user1, 2).await.unwrap(); - let (invite_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 2); - - // User 2 redeems the invite code and becomes a contact of user 1. 
- let user2_invite = db - .create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) - .await - .unwrap(); - let NewUserResult { - user_id: user2, - inviting_user_id, - signup_device_id, - metrics_id, - } = db - .create_user_from_invite( - &user2_invite, - NewUserParams { - github_login: "user2".into(), - github_user_id: 2, - invite_count: 7, - }, - ) - .await - .unwrap() - .unwrap(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); - assert_eq!(inviting_user_id, Some(user1)); - assert_eq!(signup_device_id.unwrap(), "user-2-device-id"); - assert_eq!(db.get_user_metrics_id(user2).await.unwrap(), metrics_id); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }] - ); - assert_eq!( - db.get_contacts(user2).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user2).await.unwrap().unwrap().1, - 7 - ); - - // User 3 redeems the invite code and becomes a contact of user 1. - let user3_invite = db - .create_invite_from_code(&invite_code, "user3@example.com", None) - .await - .unwrap(); - let NewUserResult { - user_id: user3, - inviting_user_id, - signup_device_id, - .. 
- } = db - .create_user_from_invite( - &user3_invite, - NewUserParams { - github_login: "user-3".into(), - github_user_id: 3, - invite_count: 3, - }, - ) - .await - .unwrap() - .unwrap(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 0); - assert_eq!(inviting_user_id, Some(user1)); - assert!(signup_device_id.is_none()); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [ - Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user3, - should_notify: true, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user3).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user3).await.unwrap().unwrap().1, - 3 - ); - - // Trying to reedem the code for the third time results in an error. - db.create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) - .await - .unwrap_err(); - - // Invite count can be updated after the code has been created. - db.set_invite_count_for_user(user1, 2).await.unwrap(); - let (latest_code, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(latest_code, invite_code); // Invite code doesn't change when we increment above 0 - assert_eq!(invite_count, 2); - - // User 4 can now redeem the invite code and becomes a contact of user 1. 
- let user4_invite = db - .create_invite_from_code(&invite_code, "user4@example.com", Some("user-4-device-id")) - .await - .unwrap(); - let user4 = db - .create_user_from_invite( - &user4_invite, - NewUserParams { - github_login: "user-4".into(), - github_user_id: 4, - invite_count: 5, - }, - ) - .await - .unwrap() - .unwrap() - .user_id; - - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); - assert_eq!( - db.get_contacts(user1).await.unwrap(), - [ - Contact::Accepted { - user_id: user2, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user3, - should_notify: true, - busy: false, - }, - Contact::Accepted { - user_id: user4, - should_notify: true, - busy: false, - } - ] - ); - assert_eq!( - db.get_contacts(user4).await.unwrap(), - [Contact::Accepted { - user_id: user1, - should_notify: false, - busy: false, - }] - ); - assert_eq!( - db.get_invite_code_for_user(user4).await.unwrap().unwrap().1, - 5 - ); - - // An existing user cannot redeem invite codes. 
- db.create_invite_from_code(&invite_code, "user2@example.com", Some("user-2-device-id")) - .await - .unwrap_err(); - let (_, invite_count) = db.get_invite_code_for_user(user1).await.unwrap().unwrap(); - assert_eq!(invite_count, 1); -} - -#[gpui::test] -async fn test_signups() { - let test_db = TestDb::postgres(build_background_executor()); - let db = test_db.db(); - - // people sign up on the waitlist - for i in 0..8 { - db.create_signup(NewSignup { - email_address: format!("person-{i}@example.com"), - platform_mac: true, - platform_linux: i % 2 == 0, - platform_windows: i % 4 == 0, - editor_features: vec!["speed".into()], - programming_languages: vec!["rust".into(), "c".into()], - device_id: Some(format!("device_id_{i}")), - }) - .await - .unwrap(); - } - - assert_eq!( - db.get_waitlist_summary().await.unwrap(), - WaitlistSummary { - count: 8, - mac_count: 8, - linux_count: 4, - windows_count: 2, - unknown_count: 0, - } - ); - - // retrieve the next batch of signup emails to send - let signups_batch1 = db.get_unsent_invites(3).await.unwrap(); - let addresses = signups_batch1 - .iter() - .map(|s| &s.email_address) - .collect::>(); - assert_eq!( - addresses, - &[ - "person-0@example.com", - "person-1@example.com", - "person-2@example.com" - ] - ); - assert_ne!( - signups_batch1[0].email_confirmation_code, - signups_batch1[1].email_confirmation_code - ); - - // the waitlist isn't updated until we record that the emails - // were successfully sent. - let signups_batch = db.get_unsent_invites(3).await.unwrap(); - assert_eq!(signups_batch, signups_batch1); - - // once the emails go out, we can retrieve the next batch - // of signups. 
- db.record_sent_invites(&signups_batch1).await.unwrap(); - let signups_batch2 = db.get_unsent_invites(3).await.unwrap(); - let addresses = signups_batch2 - .iter() - .map(|s| &s.email_address) - .collect::>(); - assert_eq!( - addresses, - &[ - "person-3@example.com", - "person-4@example.com", - "person-5@example.com" - ] - ); - - // the sent invites are excluded from the summary. - assert_eq!( - db.get_waitlist_summary().await.unwrap(), - WaitlistSummary { - count: 5, - mac_count: 5, - linux_count: 2, - windows_count: 1, - unknown_count: 0, - } - ); - - // user completes the signup process by providing their - // github account. - let NewUserResult { - user_id, - inviting_user_id, - signup_device_id, - .. - } = db - .create_user_from_invite( - &Invite { - email_address: signups_batch1[0].email_address.clone(), - email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), - }, - NewUserParams { - github_login: "person-0".into(), - github_user_id: 0, - invite_count: 5, - }, - ) - .await - .unwrap() - .unwrap(); - let user = db.get_user_by_id(user_id).await.unwrap().unwrap(); - assert!(inviting_user_id.is_none()); - assert_eq!(user.github_login, "person-0"); - assert_eq!(user.email_address.as_deref(), Some("person-0@example.com")); - assert_eq!(user.invite_count, 5); - assert_eq!(signup_device_id.unwrap(), "device_id_0"); - - // cannot redeem the same signup again. - assert!(db - .create_user_from_invite( - &Invite { - email_address: signups_batch1[0].email_address.clone(), - email_confirmation_code: signups_batch1[0].email_confirmation_code.clone(), - }, - NewUserParams { - github_login: "some-other-github_account".into(), - github_user_id: 1, - invite_count: 5, - }, - ) - .await - .unwrap() - .is_none()); - - // cannot redeem a signup with the wrong confirmation code. 
- db.create_user_from_invite( - &Invite { - email_address: signups_batch1[1].email_address.clone(), - email_confirmation_code: "the-wrong-code".to_string(), - }, - NewUserParams { - github_login: "person-1".into(), - github_user_id: 2, - invite_count: 5, - }, - ) - .await - .unwrap_err(); -} - -fn build_background_executor() -> Arc { - Deterministic::new(0).build_background() -} diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 93ff73fc838cf961b03dcd0ca5740a64625e2bae..225501c71d58fde0a3cbd9676c2ee9749dca3792 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -1,5 +1,5 @@ use crate::{ - db::{self, NewUserParams, SqliteTestDb as TestDb, UserId}, + db::{self, NewUserParams, TestDb, UserId}, rpc::{Executor, Server}, AppState, }; @@ -5665,7 +5665,7 @@ impl TestServer { async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::new(background.clone()); + let test_db = TestDb::sqlite(background.clone()); let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id), diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 8a2cdc980fbd80f62aa57a7534ab6d9ae3f61f41..4802fd82b41f5f0a069da7168a683cc7ab46e641 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -1,7 +1,6 @@ mod api; mod auth; mod db; -mod db2; mod env; mod rpc; @@ -11,7 +10,7 @@ mod integration_tests; use anyhow::anyhow; use axum::{routing::get, Router}; use collab::{Error, Result}; -use db::DefaultDb as Db; +use db::Database; use serde::Deserialize; use std::{ env::args, @@ -45,14 +44,16 @@ pub struct MigrateConfig { } pub struct AppState { - db: Arc, + db: Arc, live_kit_client: Option>, config: Config, } impl AppState { async fn new(config: Config) -> Result> { - let db = 
Db::new(&config.database_url, 5).await?; + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() @@ -92,7 +93,9 @@ async fn main() -> Result<()> { } Some("migrate") => { let config = envy::from_env::().expect("error loading config"); - let db = Db::new(&config.database_url, 5).await?; + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; let migrations_path = config .migrations_path diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 07b98914808a6fcffc74710886a4d0c07d8e9a79..beefe54a9d6ee83b976c6d918c0f94efd87229e6 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -2,7 +2,7 @@ mod connection_pool; use crate::{ auth, - db::{self, DefaultDb, ProjectId, RoomId, User, UserId}, + db::{self, Database, ProjectId, RoomId, User, UserId}, AppState, Result, }; use anyhow::anyhow; @@ -128,10 +128,10 @@ impl fmt::Debug for Session { } } -struct DbHandle(Arc); +struct DbHandle(Arc); impl Deref for DbHandle { - type Target = DefaultDb; + type Target = Database; fn deref(&self) -> &Self::Target { self.0.as_ref() From db1d93576f8aea0364e52ddf1abdf92f74ea0dc1 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 15:13:34 +0100 Subject: [PATCH 083/109] Go back to a compiling state, panicking on unimplemented db methods --- crates/collab/src/db.rs | 1261 +++++++++++++++++- crates/collab/src/db/project.rs | 12 + crates/collab/src/db/project_collaborator.rs | 4 +- crates/collab/src/db/user.rs | 8 + 4 files changed, 1237 insertions(+), 48 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index d89d041f2a832d17201e0c4f23d1d76aed32a5ef..c5f2f98d0b615d7352908d6c1fbb52ff0cb68aa8 100644 --- a/crates/collab/src/db.rs +++ 
b/crates/collab/src/db.rs @@ -12,7 +12,7 @@ mod worktree; use crate::{Error, Result}; use anyhow::anyhow; -use collections::HashMap; +use collections::{BTreeMap, HashMap, HashSet}; pub use contact::Contact; use dashmap::DashMap; use futures::StreamExt; @@ -255,6 +255,19 @@ impl Database { .await } + pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { + self.transact(|tx| async move { + user::Entity::update_many() + .filter(user::Column::Id.eq(id)) + .col_expr(user::Column::ConnectedOnce, connected_once.into()) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + pub async fn destroy_user(&self, id: UserId) -> Result<()> { self.transact(|tx| async move { access_token::Entity::delete_many() @@ -360,6 +373,17 @@ impl Database { .await } + pub async fn is_user_busy(&self, user_id: UserId) -> Result { + self.transact(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::UserId.eq(user_id)) + .one(&tx) + .await?; + Ok(participant.is_some()) + }) + .await + } + pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { self.transact(|tx| async move { let (id_a, id_b) = if user_id_1 < user_id_2 { @@ -896,63 +920,447 @@ impl Database { .await } - // projects + // rooms - pub async fn share_project( + pub async fn incoming_call_for_user( &self, - room_id: RoomId, - connection_id: ConnectionId, - worktrees: &[proto::WorktreeMetadata], - ) -> Result> { + user_id: UserId, + ) -> Result> { self.transact(|tx| async move { - let participant = room_participant::Entity::find() - .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + let pending_participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(user_id) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) .one(&tx) - .await? 
- .ok_or_else(|| anyhow!("could not find participant"))?; - if participant.room_id != room_id { - return Err(anyhow!("shared project on unexpected room"))?; - } + .await?; - let project = project::ActiveModel { - room_id: ActiveValue::set(participant.room_id), - host_user_id: ActiveValue::set(participant.user_id), - host_connection_id: ActiveValue::set(connection_id.0 as i32), - ..Default::default() + if let Some(pending_participant) = pending_participant { + let room = self.get_room(pending_participant.room_id, &tx).await?; + Ok(Self::build_incoming_call(&room, user_id)) + } else { + Ok(None) } - .insert(&tx) - .await?; + }) + .await + } - worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i32), - project_id: ActiveValue::set(project.id), - abs_path: ActiveValue::set(worktree.abs_path.clone()), - root_name: ActiveValue::set(worktree.root_name.clone()), - visible: ActiveValue::set(worktree.visible), - scan_id: ActiveValue::set(0), - is_complete: ActiveValue::set(false), - })) - .exec(&tx) - .await?; + pub async fn create_room( + &self, + user_id: UserId, + connection_id: ConnectionId, + live_kit_room: &str, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let room_id = sqlx::query_scalar( + // " + // INSERT INTO rooms (live_kit_room) + // VALUES ($1) + // RETURNING id + // ", + // ) + // .bind(&live_kit_room) + // .fetch_one(&mut tx) + // .await + // .map(RoomId)?; + + // sqlx::query( + // " + // INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id) + // VALUES ($1, $2, $3, $4, $5) + // ", + // ) + // .bind(room_id) + // .bind(user_id) + // .bind(connection_id.0 as i32) + // .bind(user_id) + // .bind(connection_id.0 as i32) + // .execute(&mut tx) + // .await?; + + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, room).await + }) + .await + } - 
project_collaborator::ActiveModel { - project_id: ActiveValue::set(project.id), - connection_id: ActiveValue::set(connection_id.0 as i32), - user_id: ActiveValue::set(participant.user_id), - replica_id: ActiveValue::set(0), - is_host: ActiveValue::set(true), - ..Default::default() - } - .insert(&tx) - .await?; + pub async fn call( + &self, + room_id: RoomId, + calling_user_id: UserId, + calling_connection_id: ConnectionId, + called_user_id: UserId, + initial_project_id: Option, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // sqlx::query( + // " + // INSERT INTO room_participants ( + // room_id, + // user_id, + // calling_user_id, + // calling_connection_id, + // initial_project_id + // ) + // VALUES ($1, $2, $3, $4, $5) + // ", + // ) + // .bind(room_id) + // .bind(called_user_id) + // .bind(calling_user_id) + // .bind(calling_connection_id.0 as i32) + // .bind(initial_project_id) + // .execute(&mut tx) + // .await?; + + // let room = self.get_room(room_id, &mut tx).await?; + // let incoming_call = Self::build_incoming_call(&room, called_user_id) + // .ok_or_else(|| anyhow!("failed to build incoming call"))?; + // self.commit_room_transaction(room_id, tx, (room, incoming_call)) + // .await + }) + .await + } - let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (project.id, room)) - .await + pub async fn call_failed( + &self, + room_id: RoomId, + called_user_id: UserId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // sqlx::query( + // " + // DELETE FROM room_participants + // WHERE room_id = $1 AND user_id = $2 + // ", + // ) + // .bind(room_id) + // .bind(called_user_id) + // .execute(&mut tx) + // .await?; + + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, room).await }) .await } + pub async fn decline_call( + &self, + expected_room_id: Option, + user_id: UserId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let 
room_id = sqlx::query_scalar( + // " + // DELETE FROM room_participants + // WHERE user_id = $1 AND answering_connection_id IS NULL + // RETURNING room_id + // ", + // ) + // .bind(user_id) + // .fetch_one(&mut tx) + // .await?; + // if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + // return Err(anyhow!("declining call on unexpected room"))?; + // } + + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, room).await + }) + .await + } + + pub async fn cancel_call( + &self, + expected_room_id: Option, + calling_connection_id: ConnectionId, + called_user_id: UserId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let room_id = sqlx::query_scalar( + // " + // DELETE FROM room_participants + // WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL + // RETURNING room_id + // ", + // ) + // .bind(called_user_id) + // .bind(calling_connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + // if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + // return Err(anyhow!("canceling call on unexpected room"))?; + // } + + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, room).await + }) + .await + } + + pub async fn join_room( + &self, + room_id: RoomId, + user_id: UserId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // sqlx::query( + // " + // UPDATE room_participants + // SET answering_connection_id = $1 + // WHERE room_id = $2 AND user_id = $3 + // RETURNING 1 + // ", + // ) + // .bind(connection_id.0 as i32) + // .bind(room_id) + // .bind(user_id) + // .fetch_one(&mut tx) + // .await?; + + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, room).await + }) + .await + } + + pub async fn leave_room( + &self, + connection_id: ConnectionId, + ) -> Result>> { 
+ self.transact(|tx| async move { + todo!() + // // Leave room. + // let room_id = sqlx::query_scalar::<_, RoomId>( + // " + // DELETE FROM room_participants + // WHERE answering_connection_id = $1 + // RETURNING room_id + // ", + // ) + // .bind(connection_id.0 as i32) + // .fetch_optional(&mut tx) + // .await?; + + // if let Some(room_id) = room_id { + // // Cancel pending calls initiated by the leaving user. + // let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( + // " + // DELETE FROM room_participants + // WHERE calling_connection_id = $1 AND answering_connection_id IS NULL + // RETURNING user_id + // ", + // ) + // .bind(connection_id.0 as i32) + // .fetch_all(&mut tx) + // .await?; + + // let project_ids = sqlx::query_scalar::<_, ProjectId>( + // " + // SELECT project_id + // FROM project_collaborators + // WHERE connection_id = $1 + // ", + // ) + // .bind(connection_id.0 as i32) + // .fetch_all(&mut tx) + // .await?; + + // // Leave projects. + // let mut left_projects = HashMap::default(); + // if !project_ids.is_empty() { + // let mut params = "?,".repeat(project_ids.len()); + // params.pop(); + // let query = format!( + // " + // SELECT * + // FROM project_collaborators + // WHERE project_id IN ({params}) + // " + // ); + // let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query); + // for project_id in project_ids { + // query = query.bind(project_id); + // } + + // let mut project_collaborators = query.fetch(&mut tx); + // while let Some(collaborator) = project_collaborators.next().await { + // let collaborator = collaborator?; + // let left_project = + // left_projects + // .entry(collaborator.project_id) + // .or_insert(LeftProject { + // id: collaborator.project_id, + // host_user_id: Default::default(), + // connection_ids: Default::default(), + // host_connection_id: Default::default(), + // }); + + // let collaborator_connection_id = + // ConnectionId(collaborator.connection_id as u32); + // if collaborator_connection_id != 
connection_id { + // left_project.connection_ids.push(collaborator_connection_id); + // } + + // if collaborator.is_host { + // left_project.host_user_id = collaborator.user_id; + // left_project.host_connection_id = + // ConnectionId(collaborator.connection_id as u32); + // } + // } + // } + // sqlx::query( + // " + // DELETE FROM project_collaborators + // WHERE connection_id = $1 + // ", + // ) + // .bind(connection_id.0 as i32) + // .execute(&mut tx) + // .await?; + + // // Unshare projects. + // sqlx::query( + // " + // DELETE FROM projects + // WHERE room_id = $1 AND host_connection_id = $2 + // ", + // ) + // .bind(room_id) + // .bind(connection_id.0 as i32) + // .execute(&mut tx) + // .await?; + + // let room = self.get_room(room_id, &mut tx).await?; + // Ok(Some( + // self.commit_room_transaction( + // room_id, + // tx, + // LeftRoom { + // room, + // left_projects, + // canceled_calls_to_user_ids, + // }, + // ) + // .await?, + // )) + // } else { + // Ok(None) + // } + }) + .await + } + + pub async fn update_room_participant_location( + &self, + room_id: RoomId, + connection_id: ConnectionId, + location: proto::ParticipantLocation, + ) -> Result> { + self.transact(|tx| async { + todo!() + // let mut tx = tx; + // let location_kind; + // let location_project_id; + // match location + // .variant + // .as_ref() + // .ok_or_else(|| anyhow!("invalid location"))? 
+ // { + // proto::participant_location::Variant::SharedProject(project) => { + // location_kind = 0; + // location_project_id = Some(ProjectId::from_proto(project.id)); + // } + // proto::participant_location::Variant::UnsharedProject(_) => { + // location_kind = 1; + // location_project_id = None; + // } + // proto::participant_location::Variant::External(_) => { + // location_kind = 2; + // location_project_id = None; + // } + // } + + // sqlx::query( + // " + // UPDATE room_participants + // SET location_kind = $1, location_project_id = $2 + // WHERE room_id = $3 AND answering_connection_id = $4 + // RETURNING 1 + // ", + // ) + // .bind(location_kind) + // .bind(location_project_id) + // .bind(room_id) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, room).await + }) + .await + } + + async fn get_guest_connection_ids( + &self, + project_id: ProjectId, + tx: &DatabaseTransaction, + ) -> Result> { + todo!() + // let mut guest_connection_ids = Vec::new(); + // let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( + // " + // SELECT connection_id + // FROM project_collaborators + // WHERE project_id = $1 AND is_host = FALSE + // ", + // ) + // .bind(project_id) + // .fetch(tx); + // while let Some(connection_id) = db_guest_connection_ids.next().await { + // guest_connection_ids.push(ConnectionId(connection_id? 
as u32)); + // } + // Ok(guest_connection_ids) + } + + fn build_incoming_call( + room: &proto::Room, + called_user_id: UserId, + ) -> Option { + let pending_participant = room + .pending_participants + .iter() + .find(|participant| participant.user_id == called_user_id.to_proto())?; + + Some(proto::IncomingCall { + room_id: room.id, + calling_user_id: pending_participant.calling_user_id, + participant_user_ids: room + .participants + .iter() + .map(|participant| participant.user_id) + .collect(), + initial_project: room.participants.iter().find_map(|participant| { + let initial_project_id = pending_participant.initial_project_id?; + participant + .projects + .iter() + .find(|project| project.id == initial_project_id) + .cloned() + }), + }) + } + async fn get_room(&self, room_id: RoomId, tx: &DatabaseTransaction) -> Result { let db_room = room::Entity::find_by_id(room_id) .one(tx) @@ -1057,6 +1465,736 @@ impl Database { }) } + // projects + + pub async fn project_count_excluding_admins(&self) -> Result { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + Count, + } + + self.transact(|tx| async move { + Ok(project::Entity::find() + .select_only() + .column_as(project::Column::Id.count(), QueryAs::Count) + .inner_join(user::Entity) + .filter(user::Column::Admin.eq(false)) + .into_values::<_, QueryAs>() + .one(&tx) + .await? + .unwrap_or(0) as usize) + }) + .await + } + + pub async fn share_project( + &self, + room_id: RoomId, + connection_id: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result> { + self.transact(|tx| async move { + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("could not find participant"))?; + if participant.room_id != room_id { + return Err(anyhow!("shared project on unexpected room"))?; + } + + let project = project::ActiveModel { + room_id: ActiveValue::set(participant.room_id), + host_user_id: ActiveValue::set(participant.user_id), + host_connection_id: ActiveValue::set(connection_id.0 as i32), + ..Default::default() + } + .insert(&tx) + .await?; + + worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i32), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + })) + .exec(&tx) + .await?; + + project_collaborator::ActiveModel { + project_id: ActiveValue::set(project.id), + connection_id: ActiveValue::set(connection_id.0 as i32), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(ReplicaId(0)), + is_host: ActiveValue::set(true), + ..Default::default() + } + .insert(&tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, (project.id, room)) + .await + }) + .await + } + + pub async fn unshare_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result)>> { + self.transact(|tx| async move { + todo!() + // let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; + // let room_id: RoomId = sqlx::query_scalar( + // " + // DELETE FROM projects + // WHERE id = $1 AND host_connection_id = $2 + // RETURNING room_id + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) + // .await + 
}) + .await + } + + pub async fn update_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + worktrees: &[proto::WorktreeMetadata], + ) -> Result)>> { + self.transact(|tx| async move { + todo!() + // let room_id: RoomId = sqlx::query_scalar( + // " + // SELECT room_id + // FROM projects + // WHERE id = $1 AND host_connection_id = $2 + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + + // if !worktrees.is_empty() { + // let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); + // params.pop(); + // let query = format!( + // " + // INSERT INTO worktrees ( + // project_id, + // id, + // root_name, + // abs_path, + // visible, + // scan_id, + // is_complete + // ) + // VALUES {params} + // ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name + // " + // ); + + // let mut query = sqlx::query(&query); + // for worktree in worktrees { + // query = query + // .bind(project_id) + // .bind(worktree.id as i32) + // .bind(&worktree.root_name) + // .bind(&worktree.abs_path) + // .bind(worktree.visible) + // .bind(0) + // .bind(false) + // } + // query.execute(&mut tx).await?; + // } + + // let mut params = "?,".repeat(worktrees.len()); + // if !worktrees.is_empty() { + // params.pop(); + // } + // let query = format!( + // " + // DELETE FROM worktrees + // WHERE project_id = ? 
AND id NOT IN ({params}) + // ", + // ); + + // let mut query = sqlx::query(&query).bind(project_id); + // for worktree in worktrees { + // query = query.bind(WorktreeId(worktree.id as i32)); + // } + // query.execute(&mut tx).await?; + + // let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; + // let room = self.get_room(room_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) + // .await + }) + .await + } + + pub async fn update_worktree( + &self, + update: &proto::UpdateWorktree, + connection_id: ConnectionId, + ) -> Result>> { + self.transact(|tx| async move { + todo!() + // let project_id = ProjectId::from_proto(update.project_id); + // let worktree_id = WorktreeId::from_proto(update.worktree_id); + + // // Ensure the update comes from the host. + // let room_id: RoomId = sqlx::query_scalar( + // " + // SELECT room_id + // FROM projects + // WHERE id = $1 AND host_connection_id = $2 + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + + // // Update metadata. 
+ // sqlx::query( + // " + // UPDATE worktrees + // SET + // root_name = $1, + // scan_id = $2, + // is_complete = $3, + // abs_path = $4 + // WHERE project_id = $5 AND id = $6 + // RETURNING 1 + // ", + // ) + // .bind(&update.root_name) + // .bind(update.scan_id as i64) + // .bind(update.is_last_update) + // .bind(&update.abs_path) + // .bind(project_id) + // .bind(worktree_id) + // .fetch_one(&mut tx) + // .await?; + + // if !update.updated_entries.is_empty() { + // let mut params = + // "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len()); + // params.pop(); + + // let query = format!( + // " + // INSERT INTO worktree_entries ( + // project_id, + // worktree_id, + // id, + // is_dir, + // path, + // inode, + // mtime_seconds, + // mtime_nanos, + // is_symlink, + // is_ignored + // ) + // VALUES {params} + // ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET + // is_dir = excluded.is_dir, + // path = excluded.path, + // inode = excluded.inode, + // mtime_seconds = excluded.mtime_seconds, + // mtime_nanos = excluded.mtime_nanos, + // is_symlink = excluded.is_symlink, + // is_ignored = excluded.is_ignored + // " + // ); + // let mut query = sqlx::query(&query); + // for entry in &update.updated_entries { + // let mtime = entry.mtime.clone().unwrap_or_default(); + // query = query + // .bind(project_id) + // .bind(worktree_id) + // .bind(entry.id as i64) + // .bind(entry.is_dir) + // .bind(&entry.path) + // .bind(entry.inode as i64) + // .bind(mtime.seconds as i64) + // .bind(mtime.nanos as i32) + // .bind(entry.is_symlink) + // .bind(entry.is_ignored); + // } + // query.execute(&mut tx).await?; + // } + + // if !update.removed_entries.is_empty() { + // let mut params = "?,".repeat(update.removed_entries.len()); + // params.pop(); + // let query = format!( + // " + // DELETE FROM worktree_entries + // WHERE project_id = ? AND worktree_id = ? 
AND id IN ({params}) + // " + // ); + + // let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id); + // for entry_id in &update.removed_entries { + // query = query.bind(*entry_id as i64); + // } + // query.execute(&mut tx).await?; + // } + + // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, connection_ids) + // .await + }) + .await + } + + pub async fn update_diagnostic_summary( + &self, + update: &proto::UpdateDiagnosticSummary, + connection_id: ConnectionId, + ) -> Result>> { + self.transact(|tx| async { + todo!() + // let project_id = ProjectId::from_proto(update.project_id); + // let worktree_id = WorktreeId::from_proto(update.worktree_id); + // let summary = update + // .summary + // .as_ref() + // .ok_or_else(|| anyhow!("invalid summary"))?; + + // // Ensure the update comes from the host. + // let room_id: RoomId = sqlx::query_scalar( + // " + // SELECT room_id + // FROM projects + // WHERE id = $1 AND host_connection_id = $2 + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + + // // Update summary. 
+ // sqlx::query( + // " + // INSERT INTO worktree_diagnostic_summaries ( + // project_id, + // worktree_id, + // path, + // language_server_id, + // error_count, + // warning_count + // ) + // VALUES ($1, $2, $3, $4, $5, $6) + // ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET + // language_server_id = excluded.language_server_id, + // error_count = excluded.error_count, + // warning_count = excluded.warning_count + // ", + // ) + // .bind(project_id) + // .bind(worktree_id) + // .bind(&summary.path) + // .bind(summary.language_server_id as i64) + // .bind(summary.error_count as i32) + // .bind(summary.warning_count as i32) + // .execute(&mut tx) + // .await?; + + // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, connection_ids) + // .await + }) + .await + } + + pub async fn start_language_server( + &self, + update: &proto::StartLanguageServer, + connection_id: ConnectionId, + ) -> Result>> { + self.transact(|tx| async { + todo!() + // let project_id = ProjectId::from_proto(update.project_id); + // let server = update + // .server + // .as_ref() + // .ok_or_else(|| anyhow!("invalid language server"))?; + + // // Ensure the update comes from the host. + // let room_id: RoomId = sqlx::query_scalar( + // " + // SELECT room_id + // FROM projects + // WHERE id = $1 AND host_connection_id = $2 + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + + // // Add the newly-started language server. 
+ // sqlx::query( + // " + // INSERT INTO language_servers (project_id, id, name) + // VALUES ($1, $2, $3) + // ON CONFLICT (project_id, id) DO UPDATE SET + // name = excluded.name + // ", + // ) + // .bind(project_id) + // .bind(server.id as i64) + // .bind(&server.name) + // .execute(&mut tx) + // .await?; + + // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; + // self.commit_room_transaction(room_id, tx, connection_ids) + // .await + }) + .await + } + + pub async fn join_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( + // " + // SELECT room_id, user_id + // FROM room_participants + // WHERE answering_connection_id = $1 + // ", + // ) + // .bind(connection_id.0 as i32) + // .fetch_one(&mut tx) + // .await?; + + // // Ensure project id was shared on this room. + // sqlx::query( + // " + // SELECT 1 + // FROM projects + // WHERE id = $1 AND room_id = $2 + // ", + // ) + // .bind(project_id) + // .bind(room_id) + // .fetch_one(&mut tx) + // .await?; + + // let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>( + // " + // SELECT * + // FROM project_collaborators + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch_all(&mut tx) + // .await?; + // let replica_ids = collaborators + // .iter() + // .map(|c| c.replica_id) + // .collect::>(); + // let mut replica_id = ReplicaId(1); + // while replica_ids.contains(&replica_id) { + // replica_id.0 += 1; + // } + // let new_collaborator = ProjectCollaborator { + // project_id, + // connection_id: connection_id.0 as i32, + // user_id, + // replica_id, + // is_host: false, + // }; + + // sqlx::query( + // " + // INSERT INTO project_collaborators ( + // project_id, + // connection_id, + // user_id, + // replica_id, + // is_host + // ) + // VALUES ($1, $2, $3, $4, $5) + // ", + // ) + // 
.bind(new_collaborator.project_id) + // .bind(new_collaborator.connection_id) + // .bind(new_collaborator.user_id) + // .bind(new_collaborator.replica_id) + // .bind(new_collaborator.is_host) + // .execute(&mut tx) + // .await?; + // collaborators.push(new_collaborator); + + // let worktree_rows = sqlx::query_as::<_, WorktreeRow>( + // " + // SELECT * + // FROM worktrees + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch_all(&mut tx) + // .await?; + // let mut worktrees = worktree_rows + // .into_iter() + // .map(|worktree_row| { + // ( + // worktree_row.id, + // Worktree { + // id: worktree_row.id, + // abs_path: worktree_row.abs_path, + // root_name: worktree_row.root_name, + // visible: worktree_row.visible, + // entries: Default::default(), + // diagnostic_summaries: Default::default(), + // scan_id: worktree_row.scan_id as u64, + // is_complete: worktree_row.is_complete, + // }, + // ) + // }) + // .collect::>(); + + // // Populate worktree entries. + // { + // let mut entries = sqlx::query_as::<_, WorktreeEntry>( + // " + // SELECT * + // FROM worktree_entries + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch(&mut tx); + // while let Some(entry) = entries.next().await { + // let entry = entry?; + // if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) { + // worktree.entries.push(proto::Entry { + // id: entry.id as u64, + // is_dir: entry.is_dir, + // path: entry.path, + // inode: entry.inode as u64, + // mtime: Some(proto::Timestamp { + // seconds: entry.mtime_seconds as u64, + // nanos: entry.mtime_nanos as u32, + // }), + // is_symlink: entry.is_symlink, + // is_ignored: entry.is_ignored, + // }); + // } + // } + // } + + // // Populate worktree diagnostic summaries. 
+ // { + // let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>( + // " + // SELECT * + // FROM worktree_diagnostic_summaries + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch(&mut tx); + // while let Some(summary) = summaries.next().await { + // let summary = summary?; + // if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) { + // worktree + // .diagnostic_summaries + // .push(proto::DiagnosticSummary { + // path: summary.path, + // language_server_id: summary.language_server_id as u64, + // error_count: summary.error_count as u32, + // warning_count: summary.warning_count as u32, + // }); + // } + // } + // } + + // // Populate language servers. + // let language_servers = sqlx::query_as::<_, LanguageServer>( + // " + // SELECT * + // FROM language_servers + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch_all(&mut tx) + // .await?; + + // self.commit_room_transaction( + // room_id, + // tx, + // ( + // Project { + // collaborators, + // worktrees, + // language_servers: language_servers + // .into_iter() + // .map(|language_server| proto::LanguageServer { + // id: language_server.id.to_proto(), + // name: language_server.name, + // }) + // .collect(), + // }, + // replica_id as ReplicaId, + // ), + // ) + // .await + }) + .await + } + + pub async fn leave_project( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let result = sqlx::query( + // " + // DELETE FROM project_collaborators + // WHERE project_id = $1 AND connection_id = $2 + // ", + // ) + // .bind(project_id) + // .bind(connection_id.0 as i32) + // .execute(&mut tx) + // .await?; + + // if result.rows_affected() == 0 { + // Err(anyhow!("not a collaborator on this project"))?; + // } + + // let connection_ids = sqlx::query_scalar::<_, i32>( + // " + // SELECT connection_id + // FROM project_collaborators + // WHERE project_id = $1 + // ", + // 
) + // .bind(project_id) + // .fetch_all(&mut tx) + // .await? + // .into_iter() + // .map(|id| ConnectionId(id as u32)) + // .collect(); + + // let (room_id, host_user_id, host_connection_id) = + // sqlx::query_as::<_, (RoomId, i32, i32)>( + // " + // SELECT room_id, host_user_id, host_connection_id + // FROM projects + // WHERE id = $1 + // ", + // ) + // .bind(project_id) + // .fetch_one(&mut tx) + // .await?; + + // self.commit_room_transaction( + // room_id, + // tx, + // LeftProject { + // id: project_id, + // host_user_id: UserId(host_user_id), + // host_connection_id: ConnectionId(host_connection_id as u32), + // connection_ids, + // }, + // ) + // .await + }) + .await + } + + pub async fn project_collaborators( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let collaborators = sqlx::query_as::<_, ProjectCollaborator>( + // " + // SELECT * + // FROM project_collaborators + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch_all(&mut tx) + // .await?; + + // if collaborators + // .iter() + // .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) + // { + // Ok(collaborators) + // } else { + // Err(anyhow!("no such project"))? + // } + }) + .await + } + + pub async fn project_connection_ids( + &self, + project_id: ProjectId, + connection_id: ConnectionId, + ) -> Result> { + self.transact(|tx| async move { + todo!() + // let connection_ids = sqlx::query_scalar::<_, i32>( + // " + // SELECT connection_id + // FROM project_collaborators + // WHERE project_id = $1 + // ", + // ) + // .bind(project_id) + // .fetch_all(&mut tx) + // .await?; + + // if connection_ids.contains(&(connection_id.0 as i32)) { + // Ok(connection_ids + // .into_iter() + // .map(|connection_id| ConnectionId(connection_id as u32)) + // .collect()) + // } else { + // Err(anyhow!("no such project"))? 
+ // } + }) + .await + } + + // access tokens + pub async fn create_access_token_hash( &self, user_id: UserId, @@ -1334,14 +2472,45 @@ macro_rules! id_type { id_type!(AccessTokenId); id_type!(ContactId); -id_type!(UserId); id_type!(RoomId); id_type!(RoomParticipantId); id_type!(ProjectId); id_type!(ProjectCollaboratorId); +id_type!(ReplicaId); id_type!(SignupId); +id_type!(UserId); id_type!(WorktreeId); +pub struct LeftRoom { + pub room: proto::Room, + pub left_projects: HashMap, + pub canceled_calls_to_user_ids: Vec, +} + +pub struct Project { + pub collaborators: Vec, + pub worktrees: BTreeMap, + pub language_servers: Vec, +} + +pub struct LeftProject { + pub id: ProjectId, + pub host_user_id: UserId, + pub host_connection_id: ConnectionId, + pub connection_ids: Vec, +} + +pub struct Worktree { + pub id: WorktreeId, + pub abs_path: String, + pub root_name: String, + pub visible: bool, + pub entries: Vec, + pub diagnostic_summaries: Vec, + pub scan_id: u64, + pub is_complete: bool, +} + #[cfg(test)] pub use test::*; diff --git a/crates/collab/src/db/project.rs b/crates/collab/src/db/project.rs index 21ee0b27d1350603f2bd5b7118cd853a49fee512..a9f0d1cb47d9b15c2cfa77a878f98c1456053385 100644 --- a/crates/collab/src/db/project.rs +++ b/crates/collab/src/db/project.rs @@ -13,6 +13,12 @@ pub struct Model { #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::HostUserId", + to = "super::user::Column::Id" + )] + HostUser, #[sea_orm( belongs_to = "super::room::Entity", from = "Column::RoomId", @@ -23,6 +29,12 @@ pub enum Relation { Worktree, } +impl Related for Entity { + fn to() -> RelationDef { + Relation::HostUser.def() + } +} + impl Related for Entity { fn to() -> RelationDef { Relation::Room.def() diff --git a/crates/collab/src/db/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs index 
3e572fe5d4fc94029bfa73c91648bfc44800aead..fb1d565e3a4bec8b8115e5b827293bf552f4a1aa 100644 --- a/crates/collab/src/db/project_collaborator.rs +++ b/crates/collab/src/db/project_collaborator.rs @@ -1,4 +1,4 @@ -use super::{ProjectCollaboratorId, ProjectId, UserId}; +use super::{ProjectCollaboratorId, ProjectId, ReplicaId, UserId}; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] @@ -9,7 +9,7 @@ pub struct Model { pub project_id: ProjectId, pub connection_id: i32, pub user_id: UserId, - pub replica_id: i32, + pub replica_id: ReplicaId, pub is_host: bool, } diff --git a/crates/collab/src/db/user.rs b/crates/collab/src/db/user.rs index b6e096f667c2e858635cb8a1b53e2c505357bd23..c2b157bd0a758880fd6fe64b079fa8760b59df5c 100644 --- a/crates/collab/src/db/user.rs +++ b/crates/collab/src/db/user.rs @@ -24,6 +24,8 @@ pub enum Relation { AccessToken, #[sea_orm(has_one = "super::room_participant::Entity")] RoomParticipant, + #[sea_orm(has_many = "super::project::Entity")] + HostedProjects, } impl Related for Entity { @@ -38,4 +40,10 @@ impl Related for Entity { } } +impl Related for Entity { + fn to() -> RelationDef { + Relation::HostedProjects.def() + } +} + impl ActiveModelBehavior for ActiveModel {} From aebc6326a9960545fc164c17d1f19ecd0e9cf010 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 15:22:12 +0100 Subject: [PATCH 084/109] Implement `Database::create_room` --- crates/collab/src/db.rs | 47 +++++++++++++++--------------------- crates/collab/src/db/room.rs | 2 +- 2 files changed, 21 insertions(+), 28 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index c5f2f98d0b615d7352908d6c1fbb52ff0cb68aa8..30049f2d05884d630bd74c8df2176528d514b585 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -953,35 +953,27 @@ impl Database { live_kit_room: &str, ) -> Result> { self.transact(|tx| async move { - todo!() - // let room_id = sqlx::query_scalar( - // " - // INSERT INTO 
rooms (live_kit_room) - // VALUES ($1) - // RETURNING id - // ", - // ) - // .bind(&live_kit_room) - // .fetch_one(&mut tx) - // .await - // .map(RoomId)?; + let room = room::ActiveModel { + live_kit_room: ActiveValue::set(live_kit_room.into()), + ..Default::default() + } + .insert(&tx) + .await?; + let room_id = room.id; - // sqlx::query( - // " - // INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id) - // VALUES ($1, $2, $3, $4, $5) - // ", - // ) - // .bind(room_id) - // .bind(user_id) - // .bind(connection_id.0 as i32) - // .bind(user_id) - // .bind(connection_id.0 as i32) - // .execute(&mut tx) - // .await?; + room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(user_id), + answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + calling_user_id: ActiveValue::set(user_id), + calling_connection_id: ActiveValue::set(connection_id.0 as i32), + ..Default::default() + } + .insert(&tx) + .await?; - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, room).await + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ -1411,6 +1403,7 @@ impl Database { }); } } + drop(db_participants); let mut db_projects = db_room .find_related(project::Entity) diff --git a/crates/collab/src/db/room.rs b/crates/collab/src/db/room.rs index b57e612d46e32dced2be353e9d7c5bffe6d200bf..7dbf03a780adbd69c1d3b492e4bcf82557ae70ab 100644 --- a/crates/collab/src/db/room.rs +++ b/crates/collab/src/db/room.rs @@ -2,7 +2,7 @@ use super::RoomId; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] -#[sea_orm(table_name = "room_participants")] +#[sea_orm(table_name = "rooms")] pub struct Model { #[sea_orm(primary_key)] pub id: RoomId, From 256e3e8e0fbcd03fcfe9e849d5252eb53318ed54 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra 
Date: Thu, 1 Dec 2022 16:06:01 +0100 Subject: [PATCH 085/109] Get basic calls working again with sea-orm --- crates/collab/src/db.rs | 471 ++++++++++++++++++++-------------------- 1 file changed, 232 insertions(+), 239 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 30049f2d05884d630bd74c8df2176528d514b585..bb1bff7ff85125ccd3658ef7d18288bf3b5a5f30 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -987,32 +987,22 @@ impl Database { initial_project_id: Option, ) -> Result> { self.transact(|tx| async move { - todo!() - // sqlx::query( - // " - // INSERT INTO room_participants ( - // room_id, - // user_id, - // calling_user_id, - // calling_connection_id, - // initial_project_id - // ) - // VALUES ($1, $2, $3, $4, $5) - // ", - // ) - // .bind(room_id) - // .bind(called_user_id) - // .bind(calling_user_id) - // .bind(calling_connection_id.0 as i32) - // .bind(initial_project_id) - // .execute(&mut tx) - // .await?; + room_participant::ActiveModel { + room_id: ActiveValue::set(room_id), + user_id: ActiveValue::set(called_user_id), + calling_user_id: ActiveValue::set(calling_user_id), + calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32), + initial_project_id: ActiveValue::set(initial_project_id), + ..Default::default() + } + .insert(&tx) + .await?; - // let room = self.get_room(room_id, &mut tx).await?; - // let incoming_call = Self::build_incoming_call(&room, called_user_id) - // .ok_or_else(|| anyhow!("failed to build incoming call"))?; - // self.commit_room_transaction(room_id, tx, (room, incoming_call)) - // .await + let room = self.get_room(room_id, &tx).await?; + let incoming_call = Self::build_incoming_call(&room, called_user_id) + .ok_or_else(|| anyhow!("failed to build incoming call"))?; + self.commit_room_transaction(room_id, tx, (room, incoming_call)) + .await }) .await } @@ -1023,20 +1013,16 @@ impl Database { called_user_id: UserId, ) -> Result> { self.transact(|tx| async move { - 
todo!() - // sqlx::query( - // " - // DELETE FROM room_participants - // WHERE room_id = $1 AND user_id = $2 - // ", - // ) - // .bind(room_id) - // .bind(called_user_id) - // .execute(&mut tx) - // .await?; - - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, room).await + room_participant::Entity::delete_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::UserId.eq(called_user_id)), + ) + .exec(&tx) + .await?; + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ -1047,23 +1033,27 @@ impl Database { user_id: UserId, ) -> Result> { self.transact(|tx| async move { - todo!() - // let room_id = sqlx::query_scalar( - // " - // DELETE FROM room_participants - // WHERE user_id = $1 AND answering_connection_id IS NULL - // RETURNING room_id - // ", - // ) - // .bind(user_id) - // .fetch_one(&mut tx) - // .await?; - // if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { - // return Err(anyhow!("declining call on unexpected room"))?; - // } + let participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(user_id) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("could not decline call"))?; + let room_id = participant.room_id; - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, room).await + if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + return Err(anyhow!("declining call on unexpected room"))?; + } + + room_participant::Entity::delete(participant.into_active_model()) + .exec(&tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ -1075,24 +1065,30 @@ impl Database { called_user_id: UserId, ) -> Result> { self.transact(|tx| async move { - todo!() - // let room_id = sqlx::query_scalar( - // " - // DELETE FROM room_participants - // WHERE user_id = $1 AND calling_connection_id = $2 AND answering_connection_id IS NULL - // RETURNING room_id - // ", - // ) - // .bind(called_user_id) - // .bind(calling_connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; - // if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { - // return Err(anyhow!("canceling call on unexpected room"))?; - // } + let participant = room_participant::Entity::find() + .filter( + room_participant::Column::UserId + .eq(called_user_id) + .and( + room_participant::Column::CallingConnectionId + .eq(calling_connection_id.0 as i32), + ) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("could not cancel call"))?; + let room_id = participant.room_id; + if expected_room_id.map_or(false, |expected_room_id| expected_room_id != room_id) { + return Err(anyhow!("canceling call on unexpected room"))?; + } - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, room).await + room_participant::Entity::delete(participant.into_active_model()) + .exec(&tx) + .await?; + + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, room).await }) .await } @@ -1104,23 +1100,25 @@ impl Database { connection_id: ConnectionId, ) -> Result> { self.transact(|tx| async move { - todo!() - // sqlx::query( - // " - // UPDATE room_participants - // SET answering_connection_id = $1 - // WHERE room_id = $2 AND user_id = $3 - // RETURNING 1 - // ", - // ) - // .bind(connection_id.0 as i32) - // .bind(room_id) - // .bind(user_id) - // .fetch_one(&mut tx) - // .await?; - - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, room).await + let result = room_participant::Entity::update_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::UserId.eq(user_id)) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .col_expr( + room_participant::Column::AnsweringConnectionId, + connection_id.0.into(), + ) + .exec(&tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("room does not exist or was already joined"))? + } else { + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, room).await + } }) .await } @@ -1130,124 +1128,117 @@ impl Database { connection_id: ConnectionId, ) -> Result>> { self.transact(|tx| async move { - todo!() - // // Leave room. 
- // let room_id = sqlx::query_scalar::<_, RoomId>( - // " - // DELETE FROM room_participants - // WHERE answering_connection_id = $1 - // RETURNING room_id - // ", - // ) - // .bind(connection_id.0 as i32) - // .fetch_optional(&mut tx) - // .await?; + let leaving_participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&tx) + .await?; - // if let Some(room_id) = room_id { - // // Cancel pending calls initiated by the leaving user. - // let canceled_calls_to_user_ids: Vec = sqlx::query_scalar( - // " - // DELETE FROM room_participants - // WHERE calling_connection_id = $1 AND answering_connection_id IS NULL - // RETURNING user_id - // ", - // ) - // .bind(connection_id.0 as i32) - // .fetch_all(&mut tx) - // .await?; + if let Some(leaving_participant) = leaving_participant { + // Leave room. + let room_id = leaving_participant.room_id; + room_participant::Entity::delete_by_id(leaving_participant.id) + .exec(&tx) + .await?; - // let project_ids = sqlx::query_scalar::<_, ProjectId>( - // " - // SELECT project_id - // FROM project_collaborators - // WHERE connection_id = $1 - // ", - // ) - // .bind(connection_id.0 as i32) - // .fetch_all(&mut tx) - // .await?; + // Cancel pending calls initiated by the leaving user. + let called_participants = room_participant::Entity::find() + .filter( + room_participant::Column::CallingConnectionId + .eq(connection_id.0) + .and(room_participant::Column::AnsweringConnectionId.is_null()), + ) + .all(&tx) + .await?; + room_participant::Entity::delete_many() + .filter( + room_participant::Column::Id + .is_in(called_participants.iter().map(|participant| participant.id)), + ) + .exec(&tx) + .await?; + let canceled_calls_to_user_ids = called_participants + .into_iter() + .map(|participant| participant.user_id) + .collect(); + + // Detect left projects. 
+ #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryProjectIds { + ProjectId, + } + let project_ids: Vec = project_collaborator::Entity::find() + .select_only() + .column_as( + project_collaborator::Column::ProjectId, + QueryProjectIds::ProjectId, + ) + .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) + .into_values::<_, QueryProjectIds>() + .all(&tx) + .await?; + let mut left_projects = HashMap::default(); + let mut collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.is_in(project_ids)) + .stream(&tx) + .await?; + while let Some(collaborator) = collaborators.next().await { + let collaborator = collaborator?; + let left_project = + left_projects + .entry(collaborator.project_id) + .or_insert(LeftProject { + id: collaborator.project_id, + host_user_id: Default::default(), + connection_ids: Default::default(), + host_connection_id: Default::default(), + }); + + let collaborator_connection_id = + ConnectionId(collaborator.connection_id as u32); + if collaborator_connection_id != connection_id { + left_project.connection_ids.push(collaborator_connection_id); + } - // // Leave projects. 
- // let mut left_projects = HashMap::default(); - // if !project_ids.is_empty() { - // let mut params = "?,".repeat(project_ids.len()); - // params.pop(); - // let query = format!( - // " - // SELECT * - // FROM project_collaborators - // WHERE project_id IN ({params}) - // " - // ); - // let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query); - // for project_id in project_ids { - // query = query.bind(project_id); - // } + if collaborator.is_host { + left_project.host_user_id = collaborator.user_id; + left_project.host_connection_id = + ConnectionId(collaborator.connection_id as u32); + } + } + drop(collaborators); - // let mut project_collaborators = query.fetch(&mut tx); - // while let Some(collaborator) = project_collaborators.next().await { - // let collaborator = collaborator?; - // let left_project = - // left_projects - // .entry(collaborator.project_id) - // .or_insert(LeftProject { - // id: collaborator.project_id, - // host_user_id: Default::default(), - // connection_ids: Default::default(), - // host_connection_id: Default::default(), - // }); - - // let collaborator_connection_id = - // ConnectionId(collaborator.connection_id as u32); - // if collaborator_connection_id != connection_id { - // left_project.connection_ids.push(collaborator_connection_id); - // } - - // if collaborator.is_host { - // left_project.host_user_id = collaborator.user_id; - // left_project.host_connection_id = - // ConnectionId(collaborator.connection_id as u32); - // } - // } - // } - // sqlx::query( - // " - // DELETE FROM project_collaborators - // WHERE connection_id = $1 - // ", - // ) - // .bind(connection_id.0 as i32) - // .execute(&mut tx) - // .await?; + // Leave projects. + project_collaborator::Entity::delete_many() + .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) + .exec(&tx) + .await?; - // // Unshare projects. 
- // sqlx::query( - // " - // DELETE FROM projects - // WHERE room_id = $1 AND host_connection_id = $2 - // ", - // ) - // .bind(room_id) - // .bind(connection_id.0 as i32) - // .execute(&mut tx) - // .await?; + // Unshare projects. + project::Entity::delete_many() + .filter( + project::Column::RoomId + .eq(room_id) + .and(project::Column::HostConnectionId.eq(connection_id.0)), + ) + .exec(&tx) + .await?; - // let room = self.get_room(room_id, &mut tx).await?; - // Ok(Some( - // self.commit_room_transaction( - // room_id, - // tx, - // LeftRoom { - // room, - // left_projects, - // canceled_calls_to_user_ids, - // }, - // ) - // .await?, - // )) - // } else { - // Ok(None) - // } + let room = self.get_room(room_id, &tx).await?; + Ok(Some( + self.commit_room_transaction( + room_id, + tx, + LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + }, + ) + .await?, + )) + } else { + Ok(None) + } }) .await } @@ -1259,46 +1250,48 @@ impl Database { location: proto::ParticipantLocation, ) -> Result> { self.transact(|tx| async { - todo!() - // let mut tx = tx; - // let location_kind; - // let location_project_id; - // match location - // .variant - // .as_ref() - // .ok_or_else(|| anyhow!("invalid location"))? - // { - // proto::participant_location::Variant::SharedProject(project) => { - // location_kind = 0; - // location_project_id = Some(ProjectId::from_proto(project.id)); - // } - // proto::participant_location::Variant::UnsharedProject(_) => { - // location_kind = 1; - // location_project_id = None; - // } - // proto::participant_location::Variant::External(_) => { - // location_kind = 2; - // location_project_id = None; - // } - // } + let mut tx = tx; + let location_kind; + let location_project_id; + match location + .variant + .as_ref() + .ok_or_else(|| anyhow!("invalid location"))? 
+ { + proto::participant_location::Variant::SharedProject(project) => { + location_kind = 0; + location_project_id = Some(ProjectId::from_proto(project.id)); + } + proto::participant_location::Variant::UnsharedProject(_) => { + location_kind = 1; + location_project_id = None; + } + proto::participant_location::Variant::External(_) => { + location_kind = 2; + location_project_id = None; + } + } - // sqlx::query( - // " - // UPDATE room_participants - // SET location_kind = $1, location_project_id = $2 - // WHERE room_id = $3 AND answering_connection_id = $4 - // RETURNING 1 - // ", - // ) - // .bind(location_kind) - // .bind(location_project_id) - // .bind(room_id) - // .bind(connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; + let result = room_participant::Entity::update_many() + .filter( + room_participant::Column::RoomId + .eq(room_id) + .and(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)), + ) + .set(room_participant::ActiveModel { + location_kind: ActiveValue::set(Some(location_kind)), + location_project_id: ActiveValue::set(location_project_id), + ..Default::default() + }) + .exec(&tx) + .await?; - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, room).await + if result.rows_affected == 1 { + let room = self.get_room(room_id, &mut tx).await?; + self.commit_room_transaction(room_id, tx, room).await + } else { + Err(anyhow!("could not update room participant location"))? + } }) .await } From 62624b81d88ae2661125f912a598d3feccddbb5b Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 16:10:37 +0100 Subject: [PATCH 086/109] Avoid using `col_expr` whenever possible ...and use the more type-safe `::set`. 
--- crates/collab/src/db.rs | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index bb1bff7ff85125ccd3658ef7d18288bf3b5a5f30..dce217d955d749c46315991f560cb188b2b97fcb 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -246,7 +246,10 @@ impl Database { self.transact(|tx| async move { user::Entity::update_many() .filter(user::Column::Id.eq(id)) - .col_expr(user::Column::Admin, is_admin.into()) + .set(user::ActiveModel { + admin: ActiveValue::set(is_admin), + ..Default::default() + }) .exec(&tx) .await?; tx.commit().await?; @@ -259,7 +262,10 @@ impl Database { self.transact(|tx| async move { user::Entity::update_many() .filter(user::Column::Id.eq(id)) - .col_expr(user::Column::ConnectedOnce, connected_once.into()) + .set(user::ActiveModel { + connected_once: ActiveValue::set(connected_once), + ..Default::default() + }) .exec(&tx) .await?; tx.commit().await?; @@ -674,7 +680,10 @@ impl Database { self.transact(|tx| async { signup::Entity::update_many() .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) - .col_expr(signup::Column::EmailConfirmationSent, true.into()) + .set(signup::ActiveModel { + email_confirmation_sent: ActiveValue::set(true), + ..Default::default() + }) .exec(&tx) .await?; tx.commit().await?; @@ -876,14 +885,20 @@ impl Database { .eq(id) .and(user::Column::InviteCode.is_null()), ) - .col_expr(user::Column::InviteCode, random_invite_code().into()) + .set(user::ActiveModel { + invite_code: ActiveValue::set(Some(random_invite_code())), + ..Default::default() + }) .exec(&tx) .await?; } user::Entity::update_many() .filter(user::Column::Id.eq(id)) - .col_expr(user::Column::InviteCount, count.into()) + .set(user::ActiveModel { + invite_count: ActiveValue::set(count as i32), + ..Default::default() + }) .exec(&tx) .await?; tx.commit().await?; @@ -1107,10 +1122,10 @@ impl Database { 
.and(room_participant::Column::UserId.eq(user_id)) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .col_expr( - room_participant::Column::AnsweringConnectionId, - connection_id.0.into(), - ) + .set(room_participant::ActiveModel { + answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + ..Default::default() + }) .exec(&tx) .await?; if result.rows_affected == 0 { From e3ac67784a8131f8c56212f201bddd57f4ea0a75 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 16:23:29 +0100 Subject: [PATCH 087/109] Implement `Database::project_guest_connection_ids` --- crates/collab/src/db.rs | 54 ++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index dce217d955d749c46315991f560cb188b2b97fcb..96ca4e953055b29586c3c8cb57f678eaf87bcf56 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1311,28 +1311,6 @@ impl Database { .await } - async fn get_guest_connection_ids( - &self, - project_id: ProjectId, - tx: &DatabaseTransaction, - ) -> Result> { - todo!() - // let mut guest_connection_ids = Vec::new(); - // let mut db_guest_connection_ids = sqlx::query_scalar::<_, i32>( - // " - // SELECT connection_id - // FROM project_collaborators - // WHERE project_id = $1 AND is_host = FALSE - // ", - // ) - // .bind(project_id) - // .fetch(tx); - // while let Some(connection_id) = db_guest_connection_ids.next().await { - // guest_connection_ids.push(ConnectionId(connection_id? 
as u32)); - // } - // Ok(guest_connection_ids) - } - fn build_incoming_call( room: &proto::Room, called_user_id: UserId, @@ -2194,6 +2172,38 @@ impl Database { .await } + async fn project_guest_connection_ids( + &self, + project_id: ProjectId, + tx: &DatabaseTransaction, + ) -> Result> { + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + ConnectionId, + } + + let mut db_guest_connection_ids = project_collaborator::Entity::find() + .select_only() + .column_as( + project_collaborator::Column::ConnectionId, + QueryAs::ConnectionId, + ) + .filter( + project_collaborator::Column::ProjectId + .eq(project_id) + .and(project_collaborator::Column::IsHost.eq(false)), + ) + .into_values::() + .stream(tx) + .await?; + + let mut guest_connection_ids = Vec::new(); + while let Some(connection_id) = db_guest_connection_ids.next().await { + guest_connection_ids.push(ConnectionId(connection_id? as u32)); + } + Ok(guest_connection_ids) + } + // access tokens pub async fn create_access_token_hash( From 944d6554deb85dcb8ab14d1a05d4b0f77b707230 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 16:26:13 +0100 Subject: [PATCH 088/109] Implement `Database::unshare_project` --- crates/collab/src/db.rs | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 96ca4e953055b29586c3c8cb57f678eaf87bcf56..fc377ff7ac7e4d003f5c53f82b377f255b053f72 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1527,22 +1527,23 @@ impl Database { connection_id: ConnectionId, ) -> Result)>> { self.transact(|tx| async move { - todo!() - // let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - // let room_id: RoomId = sqlx::query_scalar( - // " - // DELETE FROM projects - // WHERE id = $1 AND host_connection_id = $2 - // RETURNING room_id - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // 
.fetch_one(&mut tx) - // .await?; - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - // .await + let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + + let project = project::Entity::find_by_id(project_id) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("project not found"))?; + if project.host_connection_id == connection_id.0 as i32 { + let room_id = project.room_id; + project::Entity::delete(project.into_active_model()) + .exec(&tx) + .await?; + let room = self.get_room(room_id, &tx).await?; + self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) + .await + } else { + Err(anyhow!("cannot unshare a project hosted by another user"))? + } }) .await } From cfdf0a57b8f4915018135a31309c53e1765bd8c3 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 17:36:36 +0100 Subject: [PATCH 089/109] Implement `Database::update_project` --- crates/collab/src/db.rs | 172 ++++++++++++++++------------------------ 1 file changed, 69 insertions(+), 103 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index fc377ff7ac7e4d003f5c53f82b377f255b053f72..971a8cd612f659fb6233ae170ed89c6edf64ab41 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1555,73 +1555,40 @@ impl Database { worktrees: &[proto::WorktreeMetadata], ) -> Result)>> { self.transact(|tx| async move { - todo!() - // let room_id: RoomId = sqlx::query_scalar( - // " - // SELECT room_id - // FROM projects - // WHERE id = $1 AND host_connection_id = $2 - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; - - // if !worktrees.is_empty() { - // let mut params = "(?, ?, ?, ?, ?, ?, ?),".repeat(worktrees.len()); - // params.pop(); - // let query = format!( - // " - // INSERT INTO worktrees ( - // project_id, - // id, - // root_name, - // abs_path, - // visible, - // scan_id, - // 
is_complete - // ) - // VALUES {params} - // ON CONFLICT (project_id, id) DO UPDATE SET root_name = excluded.root_name - // " - // ); - - // let mut query = sqlx::query(&query); - // for worktree in worktrees { - // query = query - // .bind(project_id) - // .bind(worktree.id as i32) - // .bind(&worktree.root_name) - // .bind(&worktree.abs_path) - // .bind(worktree.visible) - // .bind(0) - // .bind(false) - // } - // query.execute(&mut tx).await?; - // } - - // let mut params = "?,".repeat(worktrees.len()); - // if !worktrees.is_empty() { - // params.pop(); - // } - // let query = format!( - // " - // DELETE FROM worktrees - // WHERE project_id = ? AND id NOT IN ({params}) - // ", - // ); + let project = project::Entity::find_by_id(project_id) + .filter(project::Column::HostConnectionId.eq(connection_id.0)) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; - // let mut query = sqlx::query(&query).bind(project_id); - // for worktree in worktrees { - // query = query.bind(WorktreeId(worktree.id as i32)); - // } - // query.execute(&mut tx).await?; + worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i32), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + })) + .exec(&tx) + .await?; + worktree::Entity::delete_many() + .filter( + worktree::Column::ProjectId.eq(project.id).and( + worktree::Column::Id.is_not_in( + worktrees + .iter() + .map(|worktree| WorktreeId(worktree.id as i32)), + ), + ), + ) + .exec(&tx) + .await?; - // let guest_connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - // let room = self.get_room(room_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - // .await + 
let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?; + let room = self.get_room(project.room_id, &tx).await?; + self.commit_room_transaction(project.room_id, tx, (room, guest_connection_ids)) + .await }) .await } @@ -2119,26 +2086,19 @@ impl Database { connection_id: ConnectionId, ) -> Result> { self.transact(|tx| async move { - todo!() - // let collaborators = sqlx::query_as::<_, ProjectCollaborator>( - // " - // SELECT * - // FROM project_collaborators - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_all(&mut tx) - // .await?; + let collaborators = project_collaborator::Entity::find() + .filter(project_collaborator::Column::ProjectId.eq(project_id)) + .all(&tx) + .await?; - // if collaborators - // .iter() - // .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) - // { - // Ok(collaborators) - // } else { - // Err(anyhow!("no such project"))? - // } + if collaborators + .iter() + .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) + { + Ok(collaborators) + } else { + Err(anyhow!("no such project"))? + } }) .await } @@ -2149,26 +2109,32 @@ impl Database { connection_id: ConnectionId, ) -> Result> { self.transact(|tx| async move { - todo!() - // let connection_ids = sqlx::query_scalar::<_, i32>( - // " - // SELECT connection_id - // FROM project_collaborators - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_all(&mut tx) - // .await?; + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryAs { + ConnectionId, + } - // if connection_ids.contains(&(connection_id.0 as i32)) { - // Ok(connection_ids - // .into_iter() - // .map(|connection_id| ConnectionId(connection_id as u32)) - // .collect()) - // } else { - // Err(anyhow!("no such project"))? 
- // } + let mut db_connection_ids = project_collaborator::Entity::find() + .select_only() + .column_as( + project_collaborator::Column::ConnectionId, + QueryAs::ConnectionId, + ) + .filter(project_collaborator::Column::ProjectId.eq(project_id)) + .into_values::() + .stream(&tx) + .await?; + + let mut connection_ids = HashSet::default(); + while let Some(connection_id) = db_connection_ids.next().await { + connection_ids.insert(ConnectionId(connection_id? as u32)); + } + + if connection_ids.contains(&connection_id) { + Ok(connection_ids) + } else { + Err(anyhow!("no such project"))? + } }) .await } From 29a4baf3469e38e1dd77aaad0f2b07e2a11830c9 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 17:47:51 +0100 Subject: [PATCH 090/109] Replace i32 with u32 for database columns We never expect to return signed integers and so we shouldn't use a signed type. I think this was a limitation of sqlx. --- crates/collab/src/api.rs | 10 ++--- crates/collab/src/db.rs | 46 ++++++++++---------- crates/collab/src/db/project.rs | 2 +- crates/collab/src/db/project_collaborator.rs | 2 +- crates/collab/src/db/room_participant.rs | 6 +-- crates/collab/src/db/tests.rs | 2 +- crates/collab/src/db/user.rs | 4 +- crates/collab/src/db/worktree.rs | 7 ++- crates/collab/src/integration_tests.rs | 2 +- crates/collab/src/rpc.rs | 5 +-- 10 files changed, 41 insertions(+), 45 deletions(-) diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index bf183edf5440460cbd9f1d6043277266d346c8b5..a0554947917b135ac59f3ed72548bd23289ef045 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -76,7 +76,7 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR #[derive(Debug, Deserialize)] struct AuthenticatedUserParams { - github_user_id: Option, + github_user_id: Option, github_login: String, } @@ -123,14 +123,14 @@ async fn get_users( #[derive(Deserialize, Debug)] struct CreateUserParams { - github_user_id: i32, + github_user_id: u32, 
github_login: String, email_address: String, email_confirmation_code: Option, #[serde(default)] admin: bool, #[serde(default)] - invite_count: i32, + invite_count: u32, } #[derive(Serialize, Debug)] @@ -208,7 +208,7 @@ struct UpdateUserParams { } async fn update_user( - Path(user_id): Path, + Path(user_id): Path, Json(params): Json, Extension(app): Extension>, Extension(rpc_server): Extension>, @@ -230,7 +230,7 @@ async fn update_user( } async fn destroy_user( - Path(user_id): Path, + Path(user_id): Path, Extension(app): Extension>, ) -> Result<()> { app.db.destroy_user(UserId(user_id)).await?; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 971a8cd612f659fb6233ae170ed89c6edf64ab41..31ee381857af3db7b5f10c43411b1cde3b43359e 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -156,7 +156,7 @@ impl Database { pub async fn get_user_by_github_account( &self, github_login: &str, - github_user_id: Option, + github_user_id: Option, ) -> Result> { self.transact(|tx| async { let tx = tx; @@ -896,7 +896,7 @@ impl Database { user::Entity::update_many() .filter(user::Column::Id.eq(id)) .set(user::ActiveModel { - invite_count: ActiveValue::set(count as i32), + invite_count: ActiveValue::set(count), ..Default::default() }) .exec(&tx) @@ -979,9 +979,9 @@ impl Database { room_participant::ActiveModel { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(user_id), - answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + answering_connection_id: ActiveValue::set(Some(connection_id.0)), calling_user_id: ActiveValue::set(user_id), - calling_connection_id: ActiveValue::set(connection_id.0 as i32), + calling_connection_id: ActiveValue::set(connection_id.0), ..Default::default() } .insert(&tx) @@ -1006,7 +1006,7 @@ impl Database { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(called_user_id), calling_user_id: ActiveValue::set(calling_user_id), - calling_connection_id: 
ActiveValue::set(calling_connection_id.0 as i32), + calling_connection_id: ActiveValue::set(calling_connection_id.0), initial_project_id: ActiveValue::set(initial_project_id), ..Default::default() } @@ -1123,7 +1123,7 @@ impl Database { .and(room_participant::Column::AnsweringConnectionId.is_null()), ) .set(room_participant::ActiveModel { - answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + answering_connection_id: ActiveValue::set(Some(connection_id.0)), ..Default::default() }) .exec(&tx) @@ -1485,14 +1485,14 @@ impl Database { let project = project::ActiveModel { room_id: ActiveValue::set(participant.room_id), host_user_id: ActiveValue::set(participant.user_id), - host_connection_id: ActiveValue::set(connection_id.0 as i32), + host_connection_id: ActiveValue::set(connection_id.0), ..Default::default() } .insert(&tx) .await?; worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i32), + id: ActiveValue::set(WorktreeId(worktree.id as u32)), project_id: ActiveValue::set(project.id), abs_path: ActiveValue::set(worktree.abs_path.clone()), root_name: ActiveValue::set(worktree.root_name.clone()), @@ -1505,7 +1505,7 @@ impl Database { project_collaborator::ActiveModel { project_id: ActiveValue::set(project.id), - connection_id: ActiveValue::set(connection_id.0 as i32), + connection_id: ActiveValue::set(connection_id.0), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(ReplicaId(0)), is_host: ActiveValue::set(true), @@ -1533,7 +1533,7 @@ impl Database { .one(&tx) .await? 
.ok_or_else(|| anyhow!("project not found"))?; - if project.host_connection_id == connection_id.0 as i32 { + if project.host_connection_id == connection_id.0 { let room_id = project.room_id; project::Entity::delete(project.into_active_model()) .exec(&tx) @@ -1562,7 +1562,7 @@ impl Database { .ok_or_else(|| anyhow!("no such project"))?; worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i32), + id: ActiveValue::set(WorktreeId(worktree.id as u32)), project_id: ActiveValue::set(project.id), abs_path: ActiveValue::set(worktree.abs_path.clone()), root_name: ActiveValue::set(worktree.root_name.clone()), @@ -1578,7 +1578,7 @@ impl Database { worktree::Column::Id.is_not_in( worktrees .iter() - .map(|worktree| WorktreeId(worktree.id as i32)), + .map(|worktree| WorktreeId(worktree.id as u32)), ), ), ) @@ -2093,7 +2093,7 @@ impl Database { if collaborators .iter() - .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) + .any(|collaborator| collaborator.connection_id == connection_id.0) { Ok(collaborators) } else { @@ -2307,8 +2307,8 @@ impl DerefMut for RoomGuard { #[derive(Debug, Serialize, Deserialize)] pub struct NewUserParams { pub github_login: String, - pub github_user_id: i32, - pub invite_count: i32, + pub github_user_id: u32, + pub invite_count: u32, } #[derive(Debug)] @@ -2339,21 +2339,19 @@ macro_rules! id_type { PartialOrd, Ord, Hash, - sqlx::Type, Serialize, Deserialize, )] - #[sqlx(transparent)] #[serde(transparent)] - pub struct $name(pub i32); + pub struct $name(pub u32); impl $name { #[allow(unused)] - pub const MAX: Self = Self(i32::MAX); + pub const MAX: Self = Self(u32::MAX); #[allow(unused)] pub fn from_proto(value: u64) -> Self { - Self(value as i32) + Self(value as u32) } #[allow(unused)] @@ -2370,7 +2368,7 @@ macro_rules! 
id_type { impl From<$name> for sea_query::Value { fn from(value: $name) -> Self { - sea_query::Value::Int(Some(value.0)) + sea_query::Value::Unsigned(Some(value.0)) } } @@ -2380,7 +2378,7 @@ macro_rules! id_type { pre: &str, col: &str, ) -> Result { - Ok(Self(i32::try_get(res, pre, col)?)) + Ok(Self(u32::try_get(res, pre, col)?)) } } @@ -2420,11 +2418,11 @@ macro_rules! id_type { } fn array_type() -> sea_query::ArrayType { - sea_query::ArrayType::Int + sea_query::ArrayType::Unsigned } fn column_type() -> sea_query::ColumnType { - sea_query::ColumnType::Integer(None) + sea_query::ColumnType::Unsigned(None) } } diff --git a/crates/collab/src/db/project.rs b/crates/collab/src/db/project.rs index a9f0d1cb47d9b15c2cfa77a878f98c1456053385..c8083402a3041162c1280f08dd2d9d2d17e2bef0 100644 --- a/crates/collab/src/db/project.rs +++ b/crates/collab/src/db/project.rs @@ -8,7 +8,7 @@ pub struct Model { pub id: ProjectId, pub room_id: RoomId, pub host_user_id: UserId, - pub host_connection_id: i32, + pub host_connection_id: u32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs index fb1d565e3a4bec8b8115e5b827293bf552f4a1aa..bccf451a633116425c6843c5da1fc7b1ae204b9e 100644 --- a/crates/collab/src/db/project_collaborator.rs +++ b/crates/collab/src/db/project_collaborator.rs @@ -7,7 +7,7 @@ pub struct Model { #[sea_orm(primary_key)] pub id: ProjectCollaboratorId, pub project_id: ProjectId, - pub connection_id: i32, + pub connection_id: u32, pub user_id: UserId, pub replica_id: ReplicaId, pub is_host: bool, diff --git a/crates/collab/src/db/room_participant.rs b/crates/collab/src/db/room_participant.rs index c7c804581b07be6825bbc27b44227d8da4a6b26a..e8f38cf69318937fce7857891f7a2d0f4384c512 100644 --- a/crates/collab/src/db/room_participant.rs +++ b/crates/collab/src/db/room_participant.rs @@ -8,12 +8,12 @@ pub struct Model { pub id: RoomParticipantId, pub room_id: RoomId, 
pub user_id: UserId, - pub answering_connection_id: Option, - pub location_kind: Option, + pub answering_connection_id: Option, + pub location_kind: Option, pub location_project_id: Option, pub initial_project_id: Option, pub calling_user_id: UserId, - pub calling_connection_id: i32, + pub calling_connection_id: u32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index b276bd5057b7282815a4c21eeea00fd691eecff5..6ca287746a7c9106bb1230c9eb9e1245e0e179b0 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -430,7 +430,7 @@ async fn test_fuzzy_search_users() { false, NewUserParams { github_login: github_login.into(), - github_user_id: i as i32, + github_user_id: i as u32, invite_count: 0, }, ) diff --git a/crates/collab/src/db/user.rs b/crates/collab/src/db/user.rs index c2b157bd0a758880fd6fe64b079fa8760b59df5c..99292330447840124e1f471f75cc7de43f1560d3 100644 --- a/crates/collab/src/db/user.rs +++ b/crates/collab/src/db/user.rs @@ -8,11 +8,11 @@ pub struct Model { #[sea_orm(primary_key)] pub id: UserId, pub github_login: String, - pub github_user_id: Option, + pub github_user_id: Option, pub email_address: Option, pub admin: bool, pub invite_code: Option, - pub invite_count: i32, + pub invite_count: u32, pub inviter_id: Option, pub connected_once: bool, pub metrics_id: Uuid, diff --git a/crates/collab/src/db/worktree.rs b/crates/collab/src/db/worktree.rs index 3c6f7c0c1d62d274b3c2bc95e150678037117e96..8cad41e8a943860c4a622ac8fe34fde48f86f383 100644 --- a/crates/collab/src/db/worktree.rs +++ b/crates/collab/src/db/worktree.rs @@ -1,18 +1,17 @@ +use super::{ProjectId, WorktreeId}; use sea_orm::entity::prelude::*; -use super::ProjectId; - #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "worktrees")] pub struct Model { #[sea_orm(primary_key)] - pub id: i32, + pub id: WorktreeId, #[sea_orm(primary_key)] pub project_id: ProjectId, 
pub abs_path: String, pub root_name: String, pub visible: bool, - pub scan_id: i64, + pub scan_id: u32, pub is_complete: bool, } diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 225501c71d58fde0a3cbd9676c2ee9749dca3792..c5540284077317c9160a7278d00060bd068a41dd 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -5383,7 +5383,7 @@ async fn test_random_collaboration( false, NewUserParams { github_login: username.clone(), - github_user_id: (ix + 1) as i32, + github_user_id: (ix + 1) as u32, invite_count: 0, }, ) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index beefe54a9d6ee83b976c6d918c0f94efd87229e6..01866b074d0154bb254865a888fee78fa778f78a 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1008,7 +1008,7 @@ async fn join_project( let collaborators = project .collaborators .iter() - .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) + .filter(|collaborator| collaborator.connection_id != session.connection_id.0) .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, replica_id: collaborator.replica_id.0 as u32, @@ -1313,8 +1313,7 @@ async fn save_buffer( .await .project_collaborators(project_id, session.connection_id) .await?; - collaborators - .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); + collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id.0); let project_connection_ids = collaborators .into_iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); From 585ac3e1beb6aea75f929e7e80116b4c081acfa0 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Thu, 1 Dec 2022 18:39:24 +0100 Subject: [PATCH 091/109] WIP --- crates/collab/src/db.rs | 60 +++++++++++--------------- crates/collab/src/db/worktree_entry.rs | 23 ++++++++++ 2 files changed, 47 insertions(+), 36 deletions(-) 
create mode 100644 crates/collab/src/db/worktree_entry.rs diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 31ee381857af3db7b5f10c43411b1cde3b43359e..3d828b2e79fd48a3615374ad8e51019901d31349 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -9,6 +9,7 @@ mod signup; mod tests; mod user; mod worktree; +mod worktree_entry; use crate::{Error, Result}; use anyhow::anyhow; @@ -1599,44 +1600,28 @@ impl Database { connection_id: ConnectionId, ) -> Result>> { self.transact(|tx| async move { - todo!() - // let project_id = ProjectId::from_proto(update.project_id); - // let worktree_id = WorktreeId::from_proto(update.worktree_id); + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = WorktreeId::from_proto(update.worktree_id); - // // Ensure the update comes from the host. - // let room_id: RoomId = sqlx::query_scalar( - // " - // SELECT room_id - // FROM projects - // WHERE id = $1 AND host_connection_id = $2 - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .filter(project::Column::HostConnectionId.eq(connection_id.0)) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; - // // Update metadata. - // sqlx::query( - // " - // UPDATE worktrees - // SET - // root_name = $1, - // scan_id = $2, - // is_complete = $3, - // abs_path = $4 - // WHERE project_id = $5 AND id = $6 - // RETURNING 1 - // ", - // ) - // .bind(&update.root_name) - // .bind(update.scan_id as i64) - // .bind(update.is_last_update) - // .bind(&update.abs_path) - // .bind(project_id) - // .bind(worktree_id) - // .fetch_one(&mut tx) - // .await?; + // Update metadata. 
+ worktree::Entity::update(worktree::ActiveModel { + id: ActiveValue::set(worktree_id), + project_id: ActiveValue::set(project_id), + root_name: ActiveValue::set(update.root_name.clone()), + scan_id: ActiveValue::set(update.scan_id as u32), + is_complete: ActiveValue::set(update.is_last_update), + abs_path: ActiveValue::set(update.abs_path.clone()), + ..Default::default() + }) + .exec(&tx) + .await?; // if !update.updated_entries.is_empty() { // let mut params = @@ -1706,6 +1691,8 @@ impl Database { // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; // self.commit_room_transaction(room_id, tx, connection_ids) // .await + + todo!() }) .await } @@ -2456,6 +2443,7 @@ id_type!(ReplicaId); id_type!(SignupId); id_type!(UserId); id_type!(WorktreeId); +id_type!(WorktreeEntryId); pub struct LeftRoom { pub room: proto::Room, diff --git a/crates/collab/src/db/worktree_entry.rs b/crates/collab/src/db/worktree_entry.rs new file mode 100644 index 0000000000000000000000000000000000000000..8698d844c107eed7674d6cd7e14505ffce7d4ed4 --- /dev/null +++ b/crates/collab/src/db/worktree_entry.rs @@ -0,0 +1,23 @@ +use super::{ProjectId, WorktreeEntryId, WorktreeId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_entries")] +pub struct Model { + #[sea_orm(primary_key)] + project_id: ProjectId, + #[sea_orm(primary_key)] + worktree_id: WorktreeId, + #[sea_orm(primary_key)] + id: WorktreeEntryId, + is_dir: bool, + path: String, + inode: u64, + mtime_seconds: u64, + mtime_nanos: u32, + is_symlink: bool, + is_ignored: bool, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} From dec5f37e4e4f13abb33cc5717f58390496bcf32c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 13:58:23 +0100 Subject: [PATCH 092/109] Finish porting remaining db methods to sea-orm --- .../20221109000000_test_schema.sql | 2 + 
.../20221111092550_reconnection_support.sql | 6 +- crates/collab/src/db.rs | 721 ++++++++---------- crates/collab/src/db/language_server.rs | 30 + crates/collab/src/db/project.rs | 20 +- crates/collab/src/db/project_collaborator.rs | 15 +- crates/collab/src/db/worktree.rs | 6 +- .../src/db/worktree_diagnostic_summary.rs | 21 + crates/collab/src/db/worktree_entry.rs | 24 +- crates/collab/src/rpc.rs | 6 +- 10 files changed, 416 insertions(+), 435 deletions(-) create mode 100644 crates/collab/src/db/language_server.rs create mode 100644 crates/collab/src/db/worktree_diagnostic_summary.rs diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index aeb6b7f720100d6ef72bcc5221d31747de372682..e62f834fbf07ca4d4265e1f7d710323413193e64 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -72,6 +72,7 @@ CREATE TABLE "worktree_entries" ( PRIMARY KEY(project_id, worktree_id, id), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); +CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("project_id"); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); CREATE TABLE "worktree_diagnostic_summaries" ( @@ -84,6 +85,7 @@ CREATE TABLE "worktree_diagnostic_summaries" ( PRIMARY KEY(project_id, worktree_id, path), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id" ON "worktree_diagnostic_summaries" ("project_id"); CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); CREATE TABLE "language_servers" ( diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql 
b/crates/collab/migrations/20221111092550_reconnection_support.sql index b742f8e0cd0b2595641b77f756687ad17cdd9aba..a7d45a9759d300624173edffdc4bd0f28d575c34 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -22,18 +22,19 @@ CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); CREATE TABLE "worktree_entries" ( "project_id" INTEGER NOT NULL, - "worktree_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, "id" INTEGER NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, "inode" INT8 NOT NULL, - "mtime_seconds" INTEGER NOT NULL, + "mtime_seconds" INT8 NOT NULL, "mtime_nanos" INTEGER NOT NULL, "is_symlink" BOOL NOT NULL, "is_ignored" BOOL NOT NULL, PRIMARY KEY(project_id, worktree_id, id), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); +CREATE INDEX "index_worktree_entries_on_project_id" ON "worktree_entries" ("project_id"); CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree_entries" ("project_id", "worktree_id"); CREATE TABLE "worktree_diagnostic_summaries" ( @@ -46,6 +47,7 @@ CREATE TABLE "worktree_diagnostic_summaries" ( PRIMARY KEY(project_id, worktree_id, path), FOREIGN KEY(project_id, worktree_id) REFERENCES worktrees (project_id, id) ON DELETE CASCADE ); +CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id" ON "worktree_diagnostic_summaries" ("project_id"); CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" ON "worktree_diagnostic_summaries" ("project_id", "worktree_id"); CREATE TABLE "language_servers" ( diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 3d828b2e79fd48a3615374ad8e51019901d31349..b01c6e750414d775fdd3b1c883c01a58fdbf88d4 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,5 +1,6 @@ mod access_token; mod contact; +mod language_server; mod project; mod 
project_collaborator; mod room; @@ -9,6 +10,7 @@ mod signup; mod tests; mod user; mod worktree; +mod worktree_diagnostic_summary; mod worktree_entry; use crate::{Error, Result}; @@ -1493,7 +1495,7 @@ impl Database { .await?; worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(WorktreeId(worktree.id as u32)), + id: ActiveValue::set(worktree.id as i64), project_id: ActiveValue::set(project.id), abs_path: ActiveValue::set(worktree.abs_path.clone()), root_name: ActiveValue::set(worktree.root_name.clone()), @@ -1563,7 +1565,7 @@ impl Database { .ok_or_else(|| anyhow!("no such project"))?; worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(WorktreeId(worktree.id as u32)), + id: ActiveValue::set(worktree.id as i64), project_id: ActiveValue::set(project.id), abs_path: ActiveValue::set(worktree.abs_path.clone()), root_name: ActiveValue::set(worktree.root_name.clone()), @@ -1576,11 +1578,8 @@ impl Database { worktree::Entity::delete_many() .filter( worktree::Column::ProjectId.eq(project.id).and( - worktree::Column::Id.is_not_in( - worktrees - .iter() - .map(|worktree| WorktreeId(worktree.id as u32)), - ), + worktree::Column::Id + .is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)), ), ) .exec(&tx) @@ -1601,7 +1600,7 @@ impl Database { ) -> Result>> { self.transact(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); - let worktree_id = WorktreeId::from_proto(update.worktree_id); + let worktree_id = update.worktree_id as i64; // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) @@ -1609,13 +1608,14 @@ impl Database { .one(&tx) .await? .ok_or_else(|| anyhow!("no such project"))?; + let room_id = project.room_id; // Update metadata. 
worktree::Entity::update(worktree::ActiveModel { id: ActiveValue::set(worktree_id), project_id: ActiveValue::set(project_id), root_name: ActiveValue::set(update.root_name.clone()), - scan_id: ActiveValue::set(update.scan_id as u32), + scan_id: ActiveValue::set(update.scan_id as i64), is_complete: ActiveValue::set(update.is_last_update), abs_path: ActiveValue::set(update.abs_path.clone()), ..Default::default() @@ -1623,76 +1623,57 @@ impl Database { .exec(&tx) .await?; - // if !update.updated_entries.is_empty() { - // let mut params = - // "(?, ?, ?, ?, ?, ?, ?, ?, ?, ?),".repeat(update.updated_entries.len()); - // params.pop(); - - // let query = format!( - // " - // INSERT INTO worktree_entries ( - // project_id, - // worktree_id, - // id, - // is_dir, - // path, - // inode, - // mtime_seconds, - // mtime_nanos, - // is_symlink, - // is_ignored - // ) - // VALUES {params} - // ON CONFLICT (project_id, worktree_id, id) DO UPDATE SET - // is_dir = excluded.is_dir, - // path = excluded.path, - // inode = excluded.inode, - // mtime_seconds = excluded.mtime_seconds, - // mtime_nanos = excluded.mtime_nanos, - // is_symlink = excluded.is_symlink, - // is_ignored = excluded.is_ignored - // " - // ); - // let mut query = sqlx::query(&query); - // for entry in &update.updated_entries { - // let mtime = entry.mtime.clone().unwrap_or_default(); - // query = query - // .bind(project_id) - // .bind(worktree_id) - // .bind(entry.id as i64) - // .bind(entry.is_dir) - // .bind(&entry.path) - // .bind(entry.inode as i64) - // .bind(mtime.seconds as i64) - // .bind(mtime.nanos as i32) - // .bind(entry.is_symlink) - // .bind(entry.is_ignored); - // } - // query.execute(&mut tx).await?; - // } - - // if !update.removed_entries.is_empty() { - // let mut params = "?,".repeat(update.removed_entries.len()); - // params.pop(); - // let query = format!( - // " - // DELETE FROM worktree_entries - // WHERE project_id = ? AND worktree_id = ? 
AND id IN ({params}) - // " - // ); - - // let mut query = sqlx::query(&query).bind(project_id).bind(worktree_id); - // for entry_id in &update.removed_entries { - // query = query.bind(*entry_id as i64); - // } - // query.execute(&mut tx).await?; - // } - - // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, connection_ids) - // .await - - todo!() + worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| { + let mtime = entry.mtime.clone().unwrap_or_default(); + worktree_entry::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + id: ActiveValue::set(entry.id as i64), + is_dir: ActiveValue::set(entry.is_dir), + path: ActiveValue::set(entry.path.clone()), + inode: ActiveValue::set(entry.inode as i64), + mtime_seconds: ActiveValue::set(mtime.seconds as i64), + mtime_nanos: ActiveValue::set(mtime.nanos), + is_symlink: ActiveValue::set(entry.is_symlink), + is_ignored: ActiveValue::set(entry.is_ignored), + } + })) + .on_conflict( + OnConflict::columns([ + worktree_entry::Column::ProjectId, + worktree_entry::Column::WorktreeId, + worktree_entry::Column::Id, + ]) + .update_columns([ + worktree_entry::Column::IsDir, + worktree_entry::Column::Path, + worktree_entry::Column::Inode, + worktree_entry::Column::MtimeSeconds, + worktree_entry::Column::MtimeNanos, + worktree_entry::Column::IsSymlink, + worktree_entry::Column::IsIgnored, + ]) + .to_owned(), + ) + .exec(&tx) + .await?; + + worktree_entry::Entity::delete_many() + .filter( + worktree_entry::Column::ProjectId + .eq(project_id) + .and(worktree_entry::Column::WorktreeId.eq(worktree_id)) + .and( + worktree_entry::Column::Id + .is_in(update.removed_entries.iter().map(|id| *id as i64)), + ), + ) + .exec(&tx) + .await?; + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + self.commit_room_transaction(room_id, tx, connection_ids) + .await }) 
.await } @@ -1703,57 +1684,51 @@ impl Database { connection_id: ConnectionId, ) -> Result>> { self.transact(|tx| async { - todo!() - // let project_id = ProjectId::from_proto(update.project_id); - // let worktree_id = WorktreeId::from_proto(update.worktree_id); - // let summary = update - // .summary - // .as_ref() - // .ok_or_else(|| anyhow!("invalid summary"))?; - - // // Ensure the update comes from the host. - // let room_id: RoomId = sqlx::query_scalar( - // " - // SELECT room_id - // FROM projects - // WHERE id = $1 AND host_connection_id = $2 - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; - - // // Update summary. - // sqlx::query( - // " - // INSERT INTO worktree_diagnostic_summaries ( - // project_id, - // worktree_id, - // path, - // language_server_id, - // error_count, - // warning_count - // ) - // VALUES ($1, $2, $3, $4, $5, $6) - // ON CONFLICT (project_id, worktree_id, path) DO UPDATE SET - // language_server_id = excluded.language_server_id, - // error_count = excluded.error_count, - // warning_count = excluded.warning_count - // ", - // ) - // .bind(project_id) - // .bind(worktree_id) - // .bind(&summary.path) - // .bind(summary.language_server_id as i64) - // .bind(summary.error_count as i32) - // .bind(summary.warning_count as i32) - // .execute(&mut tx) - // .await?; - - // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, connection_ids) - // .await + let project_id = ProjectId::from_proto(update.project_id); + let worktree_id = update.worktree_id as i64; + let summary = update + .summary + .as_ref() + .ok_or_else(|| anyhow!("invalid summary"))?; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection_id != connection_id.0 { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } + + // Update summary. + worktree_diagnostic_summary::Entity::insert(worktree_diagnostic_summary::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + path: ActiveValue::set(summary.path.clone()), + language_server_id: ActiveValue::set(summary.language_server_id as i64), + error_count: ActiveValue::set(summary.error_count), + warning_count: ActiveValue::set(summary.warning_count), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([ + worktree_diagnostic_summary::Column::ProjectId, + worktree_diagnostic_summary::Column::WorktreeId, + worktree_diagnostic_summary::Column::Path, + ]) + .update_columns([ + worktree_diagnostic_summary::Column::LanguageServerId, + worktree_diagnostic_summary::Column::ErrorCount, + worktree_diagnostic_summary::Column::WarningCount, + ]) + .to_owned(), + ) + .exec(&tx) + .await?; + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + self.commit_room_transaction(project.room_id, tx, connection_ids) + .await }) .await } @@ -1764,44 +1739,42 @@ impl Database { connection_id: ConnectionId, ) -> Result>> { self.transact(|tx| async { - todo!() - // let project_id = ProjectId::from_proto(update.project_id); - // let server = update - // .server - // .as_ref() - // .ok_or_else(|| anyhow!("invalid language server"))?; - - // // Ensure the update comes from the host. - // let room_id: RoomId = sqlx::query_scalar( - // " - // SELECT room_id - // FROM projects - // WHERE id = $1 AND host_connection_id = $2 - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; - - // // Add the newly-started language server. 
- // sqlx::query( - // " - // INSERT INTO language_servers (project_id, id, name) - // VALUES ($1, $2, $3) - // ON CONFLICT (project_id, id) DO UPDATE SET - // name = excluded.name - // ", - // ) - // .bind(project_id) - // .bind(server.id as i64) - // .bind(&server.name) - // .execute(&mut tx) - // .await?; - - // let connection_ids = self.get_guest_connection_ids(project_id, &mut tx).await?; - // self.commit_room_transaction(room_id, tx, connection_ids) - // .await + let project_id = ProjectId::from_proto(update.project_id); + let server = update + .server + .as_ref() + .ok_or_else(|| anyhow!("invalid language server"))?; + + // Ensure the update comes from the host. + let project = project::Entity::find_by_id(project_id) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.host_connection_id != connection_id.0 { + return Err(anyhow!("can't update a project hosted by someone else"))?; + } + + // Add the newly-started language server. + language_server::Entity::insert(language_server::ActiveModel { + project_id: ActiveValue::set(project_id), + id: ActiveValue::set(server.id as i64), + name: ActiveValue::set(server.name.clone()), + ..Default::default() + }) + .on_conflict( + OnConflict::columns([ + language_server::Column::ProjectId, + language_server::Column::Id, + ]) + .update_column(language_server::Column::Name) + .to_owned(), + ) + .exec(&tx) + .await?; + + let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; + self.commit_room_transaction(project.room_id, tx, connection_ids) + .await }) .await } @@ -1812,194 +1785,135 @@ impl Database { connection_id: ConnectionId, ) -> Result> { self.transact(|tx| async move { - todo!() - // let (room_id, user_id) = sqlx::query_as::<_, (RoomId, UserId)>( - // " - // SELECT room_id, user_id - // FROM room_participants - // WHERE answering_connection_id = $1 - // ", - // ) - // .bind(connection_id.0 as i32) - // .fetch_one(&mut tx) - // .await?; - - // // Ensure project 
id was shared on this room. - // sqlx::query( - // " - // SELECT 1 - // FROM projects - // WHERE id = $1 AND room_id = $2 - // ", - // ) - // .bind(project_id) - // .bind(room_id) - // .fetch_one(&mut tx) - // .await?; - - // let mut collaborators = sqlx::query_as::<_, ProjectCollaborator>( - // " - // SELECT * - // FROM project_collaborators - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_all(&mut tx) - // .await?; - // let replica_ids = collaborators - // .iter() - // .map(|c| c.replica_id) - // .collect::>(); - // let mut replica_id = ReplicaId(1); - // while replica_ids.contains(&replica_id) { - // replica_id.0 += 1; - // } - // let new_collaborator = ProjectCollaborator { - // project_id, - // connection_id: connection_id.0 as i32, - // user_id, - // replica_id, - // is_host: false, - // }; - - // sqlx::query( - // " - // INSERT INTO project_collaborators ( - // project_id, - // connection_id, - // user_id, - // replica_id, - // is_host - // ) - // VALUES ($1, $2, $3, $4, $5) - // ", - // ) - // .bind(new_collaborator.project_id) - // .bind(new_collaborator.connection_id) - // .bind(new_collaborator.user_id) - // .bind(new_collaborator.replica_id) - // .bind(new_collaborator.is_host) - // .execute(&mut tx) - // .await?; - // collaborators.push(new_collaborator); - - // let worktree_rows = sqlx::query_as::<_, WorktreeRow>( - // " - // SELECT * - // FROM worktrees - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_all(&mut tx) - // .await?; - // let mut worktrees = worktree_rows - // .into_iter() - // .map(|worktree_row| { - // ( - // worktree_row.id, - // Worktree { - // id: worktree_row.id, - // abs_path: worktree_row.abs_path, - // root_name: worktree_row.root_name, - // visible: worktree_row.visible, - // entries: Default::default(), - // diagnostic_summaries: Default::default(), - // scan_id: worktree_row.scan_id as u64, - // is_complete: worktree_row.is_complete, - // }, - // ) - // }) - // 
.collect::>(); - - // // Populate worktree entries. - // { - // let mut entries = sqlx::query_as::<_, WorktreeEntry>( - // " - // SELECT * - // FROM worktree_entries - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch(&mut tx); - // while let Some(entry) = entries.next().await { - // let entry = entry?; - // if let Some(worktree) = worktrees.get_mut(&entry.worktree_id) { - // worktree.entries.push(proto::Entry { - // id: entry.id as u64, - // is_dir: entry.is_dir, - // path: entry.path, - // inode: entry.inode as u64, - // mtime: Some(proto::Timestamp { - // seconds: entry.mtime_seconds as u64, - // nanos: entry.mtime_nanos as u32, - // }), - // is_symlink: entry.is_symlink, - // is_ignored: entry.is_ignored, - // }); - // } - // } - // } - - // // Populate worktree diagnostic summaries. - // { - // let mut summaries = sqlx::query_as::<_, WorktreeDiagnosticSummary>( - // " - // SELECT * - // FROM worktree_diagnostic_summaries - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch(&mut tx); - // while let Some(summary) = summaries.next().await { - // let summary = summary?; - // if let Some(worktree) = worktrees.get_mut(&summary.worktree_id) { - // worktree - // .diagnostic_summaries - // .push(proto::DiagnosticSummary { - // path: summary.path, - // language_server_id: summary.language_server_id as u64, - // error_count: summary.error_count as u32, - // warning_count: summary.warning_count as u32, - // }); - // } - // } - // } - - // // Populate language servers. 
- // let language_servers = sqlx::query_as::<_, LanguageServer>( - // " - // SELECT * - // FROM language_servers - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_all(&mut tx) - // .await?; - - // self.commit_room_transaction( - // room_id, - // tx, - // ( - // Project { - // collaborators, - // worktrees, - // language_servers: language_servers - // .into_iter() - // .map(|language_server| proto::LanguageServer { - // id: language_server.id.to_proto(), - // name: language_server.name, - // }) - // .collect(), - // }, - // replica_id as ReplicaId, - // ), - // ) - // .await + let participant = room_participant::Entity::find() + .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("must join a room first"))?; + + let project = project::Entity::find_by_id(project_id) + .one(&tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; + if project.room_id != participant.room_id { + return Err(anyhow!("no such project"))?; + } + + let mut collaborators = project + .find_related(project_collaborator::Entity) + .all(&tx) + .await?; + let replica_ids = collaborators + .iter() + .map(|c| c.replica_id) + .collect::>(); + let mut replica_id = ReplicaId(1); + while replica_ids.contains(&replica_id) { + replica_id.0 += 1; + } + let new_collaborator = project_collaborator::ActiveModel { + project_id: ActiveValue::set(project_id), + connection_id: ActiveValue::set(connection_id.0), + user_id: ActiveValue::set(participant.user_id), + replica_id: ActiveValue::set(replica_id), + is_host: ActiveValue::set(false), + ..Default::default() + } + .insert(&tx) + .await?; + collaborators.push(new_collaborator); + + let db_worktrees = project.find_related(worktree::Entity).all(&tx).await?; + let mut worktrees = db_worktrees + .into_iter() + .map(|db_worktree| { + ( + db_worktree.id as u64, + Worktree { + id: db_worktree.id as u64, + abs_path: db_worktree.abs_path, + root_name: 
db_worktree.root_name, + visible: db_worktree.visible, + entries: Default::default(), + diagnostic_summaries: Default::default(), + scan_id: db_worktree.scan_id as u64, + is_complete: db_worktree.is_complete, + }, + ) + }) + .collect::>(); + + // Populate worktree entries. + { + let mut db_entries = worktree_entry::Entity::find() + .filter(worktree_entry::Column::ProjectId.eq(project_id)) + .stream(&tx) + .await?; + while let Some(db_entry) = db_entries.next().await { + let db_entry = db_entry?; + if let Some(worktree) = worktrees.get_mut(&(db_entry.worktree_id as u64)) { + worktree.entries.push(proto::Entry { + id: db_entry.id as u64, + is_dir: db_entry.is_dir, + path: db_entry.path, + inode: db_entry.inode as u64, + mtime: Some(proto::Timestamp { + seconds: db_entry.mtime_seconds as u64, + nanos: db_entry.mtime_nanos, + }), + is_symlink: db_entry.is_symlink, + is_ignored: db_entry.is_ignored, + }); + } + } + } + + // Populate worktree diagnostic summaries. + { + let mut db_summaries = worktree_diagnostic_summary::Entity::find() + .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id)) + .stream(&tx) + .await?; + while let Some(db_summary) = db_summaries.next().await { + let db_summary = db_summary?; + if let Some(worktree) = worktrees.get_mut(&(db_summary.worktree_id as u64)) { + worktree + .diagnostic_summaries + .push(proto::DiagnosticSummary { + path: db_summary.path, + language_server_id: db_summary.language_server_id as u64, + error_count: db_summary.error_count as u32, + warning_count: db_summary.warning_count as u32, + }); + } + } + } + + // Populate language servers. 
+ let language_servers = project + .find_related(language_server::Entity) + .all(&tx) + .await?; + + self.commit_room_transaction( + project.room_id, + tx, + ( + Project { + collaborators, + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect(), + }, + replica_id as ReplicaId, + ), + ) + .await }) .await } @@ -2010,59 +1924,42 @@ impl Database { connection_id: ConnectionId, ) -> Result> { self.transact(|tx| async move { - todo!() - // let result = sqlx::query( - // " - // DELETE FROM project_collaborators - // WHERE project_id = $1 AND connection_id = $2 - // ", - // ) - // .bind(project_id) - // .bind(connection_id.0 as i32) - // .execute(&mut tx) - // .await?; - - // if result.rows_affected() == 0 { - // Err(anyhow!("not a collaborator on this project"))?; - // } - - // let connection_ids = sqlx::query_scalar::<_, i32>( - // " - // SELECT connection_id - // FROM project_collaborators - // WHERE project_id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_all(&mut tx) - // .await? 
- // .into_iter() - // .map(|id| ConnectionId(id as u32)) - // .collect(); - - // let (room_id, host_user_id, host_connection_id) = - // sqlx::query_as::<_, (RoomId, i32, i32)>( - // " - // SELECT room_id, host_user_id, host_connection_id - // FROM projects - // WHERE id = $1 - // ", - // ) - // .bind(project_id) - // .fetch_one(&mut tx) - // .await?; - - // self.commit_room_transaction( - // room_id, - // tx, - // LeftProject { - // id: project_id, - // host_user_id: UserId(host_user_id), - // host_connection_id: ConnectionId(host_connection_id as u32), - // connection_ids, - // }, - // ) - // .await + let result = project_collaborator::Entity::delete_many() + .filter( + project_collaborator::Column::ProjectId + .eq(project_id) + .and(project_collaborator::Column::ConnectionId.eq(connection_id.0)), + ) + .exec(&tx) + .await?; + if result.rows_affected == 0 { + Err(anyhow!("not a collaborator on this project"))?; + } + + let project = project::Entity::find_by_id(project_id) + .one(&tx) + .await? 
+ .ok_or_else(|| anyhow!("no such project"))?; + let collaborators = project + .find_related(project_collaborator::Entity) + .all(&tx) + .await?; + let connection_ids = collaborators + .into_iter() + .map(|collaborator| ConnectionId(collaborator.connection_id)) + .collect(); + + self.commit_room_transaction( + project.room_id, + tx, + LeftProject { + id: project_id, + host_user_id: project.host_user_id, + host_connection_id: ConnectionId(project.host_connection_id), + connection_ids, + }, + ) + .await }) .await } @@ -2442,8 +2339,6 @@ id_type!(ProjectCollaboratorId); id_type!(ReplicaId); id_type!(SignupId); id_type!(UserId); -id_type!(WorktreeId); -id_type!(WorktreeEntryId); pub struct LeftRoom { pub room: proto::Room, @@ -2453,7 +2348,7 @@ pub struct LeftRoom { pub struct Project { pub collaborators: Vec, - pub worktrees: BTreeMap, + pub worktrees: BTreeMap, pub language_servers: Vec, } @@ -2465,7 +2360,7 @@ pub struct LeftProject { } pub struct Worktree { - pub id: WorktreeId, + pub id: u64, pub abs_path: String, pub root_name: String, pub visible: bool, diff --git a/crates/collab/src/db/language_server.rs b/crates/collab/src/db/language_server.rs new file mode 100644 index 0000000000000000000000000000000000000000..d2c045c12146e1c8797c4cbd7e1ae52c52829e98 --- /dev/null +++ b/crates/collab/src/db/language_server.rs @@ -0,0 +1,30 @@ +use super::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "language_servers")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub id: i64, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} + +impl ActiveModelBehavior for 
ActiveModel {} diff --git a/crates/collab/src/db/project.rs b/crates/collab/src/db/project.rs index c8083402a3041162c1280f08dd2d9d2d17e2bef0..5bf8addec8e7533da78cbc866d3e38e929dabb19 100644 --- a/crates/collab/src/db/project.rs +++ b/crates/collab/src/db/project.rs @@ -26,7 +26,11 @@ pub enum Relation { )] Room, #[sea_orm(has_many = "super::worktree::Entity")] - Worktree, + Worktrees, + #[sea_orm(has_many = "super::project_collaborator::Entity")] + Collaborators, + #[sea_orm(has_many = "super::language_server::Entity")] + LanguageServers, } impl Related for Entity { @@ -43,7 +47,19 @@ impl Related for Entity { impl Related for Entity { fn to() -> RelationDef { - Relation::Worktree.def() + Relation::Worktrees.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Collaborators.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::LanguageServers.def() } } diff --git a/crates/collab/src/db/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs index bccf451a633116425c6843c5da1fc7b1ae204b9e..56048c318150e62c3d4bbe8eefcbb085cdf6153a 100644 --- a/crates/collab/src/db/project_collaborator.rs +++ b/crates/collab/src/db/project_collaborator.rs @@ -14,6 +14,19 @@ pub struct Model { } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] -pub enum Relation {} +pub enum Relation { + #[sea_orm( + belongs_to = "super::project::Entity", + from = "Column::ProjectId", + to = "super::project::Column::Id" + )] + Project, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Project.def() + } +} impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/worktree.rs b/crates/collab/src/db/worktree.rs index 8cad41e8a943860c4a622ac8fe34fde48f86f383..b9f0f97dee05b71558a050fa808b62f56b2aefd1 100644 --- a/crates/collab/src/db/worktree.rs +++ b/crates/collab/src/db/worktree.rs @@ -1,17 +1,17 @@ -use super::{ProjectId, WorktreeId}; +use super::ProjectId; use sea_orm::entity::prelude::*; 
#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "worktrees")] pub struct Model { #[sea_orm(primary_key)] - pub id: WorktreeId, + pub id: i64, #[sea_orm(primary_key)] pub project_id: ProjectId, pub abs_path: String, pub root_name: String, pub visible: bool, - pub scan_id: u32, + pub scan_id: i64, pub is_complete: bool, } diff --git a/crates/collab/src/db/worktree_diagnostic_summary.rs b/crates/collab/src/db/worktree_diagnostic_summary.rs new file mode 100644 index 0000000000000000000000000000000000000000..49bf4f6e033f42247373732ad002f838e7ce68ad --- /dev/null +++ b/crates/collab/src/db/worktree_diagnostic_summary.rs @@ -0,0 +1,21 @@ +use super::ProjectId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "worktree_diagnostic_summaries")] +pub struct Model { + #[sea_orm(primary_key)] + pub project_id: ProjectId, + #[sea_orm(primary_key)] + pub worktree_id: i64, + #[sea_orm(primary_key)] + pub path: String, + pub language_server_id: i64, + pub error_count: u32, + pub warning_count: u32, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/worktree_entry.rs b/crates/collab/src/db/worktree_entry.rs index 8698d844c107eed7674d6cd7e14505ffce7d4ed4..f38ef7b3f78de8675e6f1486570607d09aca71db 100644 --- a/crates/collab/src/db/worktree_entry.rs +++ b/crates/collab/src/db/worktree_entry.rs @@ -1,23 +1,25 @@ -use super::{ProjectId, WorktreeEntryId, WorktreeId}; +use super::ProjectId; use sea_orm::entity::prelude::*; #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] #[sea_orm(table_name = "worktree_entries")] pub struct Model { #[sea_orm(primary_key)] - project_id: ProjectId, + pub project_id: ProjectId, #[sea_orm(primary_key)] - worktree_id: WorktreeId, + pub worktree_id: i64, #[sea_orm(primary_key)] - id: WorktreeEntryId, - is_dir: bool, - path: String, - 
inode: u64, - mtime_seconds: u64, - mtime_nanos: u32, - is_symlink: bool, - is_ignored: bool, + pub id: i64, + pub is_dir: bool, + pub path: String, + pub inode: i64, + pub mtime_seconds: i64, + pub mtime_nanos: u32, + pub is_symlink: bool, + pub is_ignored: bool, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 01866b074d0154bb254865a888fee78fa778f78a..d3b95a82e692fab62178612d431dd88f1aa30df8 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1019,7 +1019,7 @@ async fn join_project( .worktrees .iter() .map(|(id, worktree)| proto::WorktreeMetadata { - id: id.to_proto(), + id: *id, root_name: worktree.root_name.clone(), visible: worktree.visible, abs_path: worktree.abs_path.clone(), @@ -1060,7 +1060,7 @@ async fn join_project( // Stream this worktree's entries. let message = proto::UpdateWorktree { project_id: project_id.to_proto(), - worktree_id: worktree_id.to_proto(), + worktree_id, abs_path: worktree.abs_path.clone(), root_name: worktree.root_name, updated_entries: worktree.entries, @@ -1078,7 +1078,7 @@ async fn join_project( session.connection_id, proto::UpdateDiagnosticSummary { project_id: project_id.to_proto(), - worktree_id: worktree.id.to_proto(), + worktree_id: worktree.id, summary: Some(summary), }, )?; From 48b6ee313f8777856489df4f3ad0e8f2f111ed05 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 13:58:54 +0100 Subject: [PATCH 093/109] Use i32 to represent Postgres `INTEGER` types in Rust --- crates/collab/src/api.rs | 12 ++-- crates/collab/src/db.rs | 60 +++++++++---------- crates/collab/src/db/project.rs | 2 +- crates/collab/src/db/project_collaborator.rs | 2 +- crates/collab/src/db/room_participant.rs | 6 +- crates/collab/src/db/tests.rs | 2 +- crates/collab/src/db/user.rs | 4 +- .../src/db/worktree_diagnostic_summary.rs | 4 +- 
crates/collab/src/db/worktree_entry.rs | 2 +- crates/collab/src/integration_tests.rs | 2 +- crates/collab/src/rpc.rs | 7 ++- 11 files changed, 52 insertions(+), 51 deletions(-) diff --git a/crates/collab/src/api.rs b/crates/collab/src/api.rs index a0554947917b135ac59f3ed72548bd23289ef045..4c1c60a04f6aee8d7121820516ea49333b264d4d 100644 --- a/crates/collab/src/api.rs +++ b/crates/collab/src/api.rs @@ -76,7 +76,7 @@ pub async fn validate_api_token(req: Request, next: Next) -> impl IntoR #[derive(Debug, Deserialize)] struct AuthenticatedUserParams { - github_user_id: Option, + github_user_id: Option, github_login: String, } @@ -123,14 +123,14 @@ async fn get_users( #[derive(Deserialize, Debug)] struct CreateUserParams { - github_user_id: u32, + github_user_id: i32, github_login: String, email_address: String, email_confirmation_code: Option, #[serde(default)] admin: bool, #[serde(default)] - invite_count: u32, + invite_count: i32, } #[derive(Serialize, Debug)] @@ -204,11 +204,11 @@ async fn create_user( #[derive(Deserialize)] struct UpdateUserParams { admin: Option, - invite_count: Option, + invite_count: Option, } async fn update_user( - Path(user_id): Path, + Path(user_id): Path, Json(params): Json, Extension(app): Extension>, Extension(rpc_server): Extension>, @@ -230,7 +230,7 @@ async fn update_user( } async fn destroy_user( - Path(user_id): Path, + Path(user_id): Path, Extension(app): Extension>, ) -> Result<()> { app.db.destroy_user(UserId(user_id)).await?; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index b01c6e750414d775fdd3b1c883c01a58fdbf88d4..945ac1b577f31944b454597ee1713be8b37c18d9 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -159,7 +159,7 @@ impl Database { pub async fn get_user_by_github_account( &self, github_login: &str, - github_user_id: Option, + github_user_id: Option, ) -> Result> { self.transact(|tx| async { let tx = tx; @@ -879,7 +879,7 @@ impl Database { .await } - pub async fn 
set_invite_count_for_user(&self, id: UserId, count: u32) -> Result<()> { + pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> { self.transact(|tx| async move { if count > 0 { user::Entity::update_many() @@ -910,11 +910,11 @@ impl Database { .await } - pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { + pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { self.transact(|tx| async move { match user::Entity::find_by_id(id).one(&tx).await? { Some(user) if user.invite_code.is_some() => { - Ok(Some((user.invite_code.unwrap(), user.invite_count as u32))) + Ok(Some((user.invite_code.unwrap(), user.invite_count))) } _ => Ok(None), } @@ -982,9 +982,9 @@ impl Database { room_participant::ActiveModel { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(user_id), - answering_connection_id: ActiveValue::set(Some(connection_id.0)), + answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), calling_user_id: ActiveValue::set(user_id), - calling_connection_id: ActiveValue::set(connection_id.0), + calling_connection_id: ActiveValue::set(connection_id.0 as i32), ..Default::default() } .insert(&tx) @@ -1009,7 +1009,7 @@ impl Database { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(called_user_id), calling_user_id: ActiveValue::set(calling_user_id), - calling_connection_id: ActiveValue::set(calling_connection_id.0), + calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32), initial_project_id: ActiveValue::set(initial_project_id), ..Default::default() } @@ -1126,7 +1126,7 @@ impl Database { .and(room_participant::Column::AnsweringConnectionId.is_null()), ) .set(room_participant::ActiveModel { - answering_connection_id: ActiveValue::set(Some(connection_id.0)), + answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), ..Default::default() }) .exec(&tx) @@ -1488,7 +1488,7 @@ impl Database { let project = project::ActiveModel { room_id: 
ActiveValue::set(participant.room_id), host_user_id: ActiveValue::set(participant.user_id), - host_connection_id: ActiveValue::set(connection_id.0), + host_connection_id: ActiveValue::set(connection_id.0 as i32), ..Default::default() } .insert(&tx) @@ -1508,7 +1508,7 @@ impl Database { project_collaborator::ActiveModel { project_id: ActiveValue::set(project.id), - connection_id: ActiveValue::set(connection_id.0), + connection_id: ActiveValue::set(connection_id.0 as i32), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(ReplicaId(0)), is_host: ActiveValue::set(true), @@ -1536,7 +1536,7 @@ impl Database { .one(&tx) .await? .ok_or_else(|| anyhow!("project not found"))?; - if project.host_connection_id == connection_id.0 { + if project.host_connection_id == connection_id.0 as i32 { let room_id = project.room_id; project::Entity::delete(project.into_active_model()) .exec(&tx) @@ -1633,7 +1633,7 @@ impl Database { path: ActiveValue::set(entry.path.clone()), inode: ActiveValue::set(entry.inode as i64), mtime_seconds: ActiveValue::set(mtime.seconds as i64), - mtime_nanos: ActiveValue::set(mtime.nanos), + mtime_nanos: ActiveValue::set(mtime.nanos as i32), is_symlink: ActiveValue::set(entry.is_symlink), is_ignored: ActiveValue::set(entry.is_ignored), } @@ -1696,7 +1696,7 @@ impl Database { .one(&tx) .await? 
.ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id != connection_id.0 { + if project.host_connection_id != connection_id.0 as i32 { return Err(anyhow!("can't update a project hosted by someone else"))?; } @@ -1706,8 +1706,8 @@ impl Database { worktree_id: ActiveValue::set(worktree_id), path: ActiveValue::set(summary.path.clone()), language_server_id: ActiveValue::set(summary.language_server_id as i64), - error_count: ActiveValue::set(summary.error_count), - warning_count: ActiveValue::set(summary.warning_count), + error_count: ActiveValue::set(summary.error_count as i32), + warning_count: ActiveValue::set(summary.warning_count as i32), ..Default::default() }) .on_conflict( @@ -1750,7 +1750,7 @@ impl Database { .one(&tx) .await? .ok_or_else(|| anyhow!("no such project"))?; - if project.host_connection_id != connection_id.0 { + if project.host_connection_id != connection_id.0 as i32 { return Err(anyhow!("can't update a project hosted by someone else"))?; } @@ -1813,7 +1813,7 @@ impl Database { } let new_collaborator = project_collaborator::ActiveModel { project_id: ActiveValue::set(project_id), - connection_id: ActiveValue::set(connection_id.0), + connection_id: ActiveValue::set(connection_id.0 as i32), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(replica_id), is_host: ActiveValue::set(false), @@ -1859,7 +1859,7 @@ impl Database { inode: db_entry.inode as u64, mtime: Some(proto::Timestamp { seconds: db_entry.mtime_seconds as u64, - nanos: db_entry.mtime_nanos, + nanos: db_entry.mtime_nanos as u32, }), is_symlink: db_entry.is_symlink, is_ignored: db_entry.is_ignored, @@ -1946,7 +1946,7 @@ impl Database { .await?; let connection_ids = collaborators .into_iter() - .map(|collaborator| ConnectionId(collaborator.connection_id)) + .map(|collaborator| ConnectionId(collaborator.connection_id as u32)) .collect(); self.commit_room_transaction( @@ -1955,7 +1955,7 @@ impl Database { LeftProject { id: project_id, 
host_user_id: project.host_user_id, - host_connection_id: ConnectionId(project.host_connection_id), + host_connection_id: ConnectionId(project.host_connection_id as u32), connection_ids, }, ) @@ -1977,7 +1977,7 @@ impl Database { if collaborators .iter() - .any(|collaborator| collaborator.connection_id == connection_id.0) + .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) { Ok(collaborators) } else { @@ -2191,8 +2191,8 @@ impl DerefMut for RoomGuard { #[derive(Debug, Serialize, Deserialize)] pub struct NewUserParams { pub github_login: String, - pub github_user_id: u32, - pub invite_count: u32, + pub github_user_id: i32, + pub invite_count: i32, } #[derive(Debug)] @@ -2227,15 +2227,15 @@ macro_rules! id_type { Deserialize, )] #[serde(transparent)] - pub struct $name(pub u32); + pub struct $name(pub i32); impl $name { #[allow(unused)] - pub const MAX: Self = Self(u32::MAX); + pub const MAX: Self = Self(i32::MAX); #[allow(unused)] pub fn from_proto(value: u64) -> Self { - Self(value as u32) + Self(value as i32) } #[allow(unused)] @@ -2252,7 +2252,7 @@ macro_rules! id_type { impl From<$name> for sea_query::Value { fn from(value: $name) -> Self { - sea_query::Value::Unsigned(Some(value.0)) + sea_query::Value::Int(Some(value.0)) } } @@ -2262,7 +2262,7 @@ macro_rules! id_type { pre: &str, col: &str, ) -> Result { - Ok(Self(u32::try_get(res, pre, col)?)) + Ok(Self(i32::try_get(res, pre, col)?)) } } @@ -2302,11 +2302,11 @@ macro_rules! 
id_type { } fn array_type() -> sea_query::ArrayType { - sea_query::ArrayType::Unsigned + sea_query::ArrayType::Int } fn column_type() -> sea_query::ColumnType { - sea_query::ColumnType::Unsigned(None) + sea_query::ColumnType::Integer(None) } } diff --git a/crates/collab/src/db/project.rs b/crates/collab/src/db/project.rs index 5bf8addec8e7533da78cbc866d3e38e929dabb19..b109ddc4b8a6d9878eafefb4a4268bad4bc1975f 100644 --- a/crates/collab/src/db/project.rs +++ b/crates/collab/src/db/project.rs @@ -8,7 +8,7 @@ pub struct Model { pub id: ProjectId, pub room_id: RoomId, pub host_user_id: UserId, - pub host_connection_id: u32, + pub host_connection_id: i32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs index 56048c318150e62c3d4bbe8eefcbb085cdf6153a..097272fcdafcff3bfa85bccca4abfa9570b5e508 100644 --- a/crates/collab/src/db/project_collaborator.rs +++ b/crates/collab/src/db/project_collaborator.rs @@ -7,7 +7,7 @@ pub struct Model { #[sea_orm(primary_key)] pub id: ProjectCollaboratorId, pub project_id: ProjectId, - pub connection_id: u32, + pub connection_id: i32, pub user_id: UserId, pub replica_id: ReplicaId, pub is_host: bool, diff --git a/crates/collab/src/db/room_participant.rs b/crates/collab/src/db/room_participant.rs index e8f38cf69318937fce7857891f7a2d0f4384c512..c7c804581b07be6825bbc27b44227d8da4a6b26a 100644 --- a/crates/collab/src/db/room_participant.rs +++ b/crates/collab/src/db/room_participant.rs @@ -8,12 +8,12 @@ pub struct Model { pub id: RoomParticipantId, pub room_id: RoomId, pub user_id: UserId, - pub answering_connection_id: Option, - pub location_kind: Option, + pub answering_connection_id: Option, + pub location_kind: Option, pub location_project_id: Option, pub initial_project_id: Option, pub calling_user_id: UserId, - pub calling_connection_id: u32, + pub calling_connection_id: i32, } #[derive(Copy, Clone, Debug, EnumIter, 
DeriveRelation)] diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 6ca287746a7c9106bb1230c9eb9e1245e0e179b0..b276bd5057b7282815a4c21eeea00fd691eecff5 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -430,7 +430,7 @@ async fn test_fuzzy_search_users() { false, NewUserParams { github_login: github_login.into(), - github_user_id: i as u32, + github_user_id: i as i32, invite_count: 0, }, ) diff --git a/crates/collab/src/db/user.rs b/crates/collab/src/db/user.rs index 99292330447840124e1f471f75cc7de43f1560d3..c2b157bd0a758880fd6fe64b079fa8760b59df5c 100644 --- a/crates/collab/src/db/user.rs +++ b/crates/collab/src/db/user.rs @@ -8,11 +8,11 @@ pub struct Model { #[sea_orm(primary_key)] pub id: UserId, pub github_login: String, - pub github_user_id: Option, + pub github_user_id: Option, pub email_address: Option, pub admin: bool, pub invite_code: Option, - pub invite_count: u32, + pub invite_count: i32, pub inviter_id: Option, pub connected_once: bool, pub metrics_id: Uuid, diff --git a/crates/collab/src/db/worktree_diagnostic_summary.rs b/crates/collab/src/db/worktree_diagnostic_summary.rs index 49bf4f6e033f42247373732ad002f838e7ce68ad..f3dd8083fb57d9e863ea51de4e4de26b2d594a61 100644 --- a/crates/collab/src/db/worktree_diagnostic_summary.rs +++ b/crates/collab/src/db/worktree_diagnostic_summary.rs @@ -11,8 +11,8 @@ pub struct Model { #[sea_orm(primary_key)] pub path: String, pub language_server_id: i64, - pub error_count: u32, - pub warning_count: u32, + pub error_count: i32, + pub warning_count: i32, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/worktree_entry.rs b/crates/collab/src/db/worktree_entry.rs index f38ef7b3f78de8675e6f1486570607d09aca71db..413821201a20dd392713f43ec3b4163d7ff31f88 100644 --- a/crates/collab/src/db/worktree_entry.rs +++ b/crates/collab/src/db/worktree_entry.rs @@ -14,7 +14,7 @@ pub struct Model { pub path: String, pub inode: i64, pub 
mtime_seconds: i64, - pub mtime_nanos: u32, + pub mtime_nanos: i32, pub is_symlink: bool, pub is_ignored: bool, } diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index c5540284077317c9160a7278d00060bd068a41dd..225501c71d58fde0a3cbd9676c2ee9749dca3792 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -5383,7 +5383,7 @@ async fn test_random_collaboration( false, NewUserParams { github_login: username.clone(), - github_user_id: (ix + 1) as u32, + github_user_id: (ix + 1) as i32, invite_count: 0, }, ) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index d3b95a82e692fab62178612d431dd88f1aa30df8..9d3917a417ef4bde4f20b09771ab11fbdc26acfd 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -381,7 +381,7 @@ impl Server { if let Some((code, count)) = invite_code { this.peer.send(connection_id, proto::UpdateInviteInfo { url: format!("{}{}", this.app_state.config.invite_link_prefix, code), - count, + count: count as u32, })?; } } @@ -1008,7 +1008,7 @@ async fn join_project( let collaborators = project .collaborators .iter() - .filter(|collaborator| collaborator.connection_id != session.connection_id.0) + .filter(|collaborator| collaborator.connection_id != session.connection_id.0 as i32) .map(|collaborator| proto::Collaborator { peer_id: collaborator.connection_id as u32, replica_id: collaborator.replica_id.0 as u32, @@ -1313,7 +1313,8 @@ async fn save_buffer( .await .project_collaborators(project_id, session.connection_id) .await?; - collaborators.retain(|collaborator| collaborator.connection_id != session.connection_id.0); + collaborators + .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); let project_connection_ids = collaborators .into_iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); From 7502558631e6cb301114b53fcc948da19b38b200 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra 
Date: Fri, 2 Dec 2022 14:22:36 +0100 Subject: [PATCH 094/109] Make all tests pass again after migration to sea-orm --- .../20221111092550_reconnection_support.sql | 10 +- crates/collab/src/db.rs | 150 ++++++++++-------- 2 files changed, 89 insertions(+), 71 deletions(-) diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index a7d45a9759d300624173edffdc4bd0f28d575c34..d23dbfa046942c22802a05a2bbe86ae600044f61 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -10,7 +10,7 @@ ALTER TABLE "projects" CREATE TABLE "worktrees" ( "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, - "id" INTEGER NOT NULL, + "id" INT8 NOT NULL, "root_name" VARCHAR NOT NULL, "abs_path" VARCHAR NOT NULL, "visible" BOOL NOT NULL, @@ -23,7 +23,7 @@ CREATE INDEX "index_worktrees_on_project_id" ON "worktrees" ("project_id"); CREATE TABLE "worktree_entries" ( "project_id" INTEGER NOT NULL, "worktree_id" INT8 NOT NULL, - "id" INTEGER NOT NULL, + "id" INT8 NOT NULL, "is_dir" BOOL NOT NULL, "path" VARCHAR NOT NULL, "inode" INT8 NOT NULL, @@ -39,9 +39,9 @@ CREATE INDEX "index_worktree_entries_on_project_id_and_worktree_id" ON "worktree CREATE TABLE "worktree_diagnostic_summaries" ( "project_id" INTEGER NOT NULL, - "worktree_id" INTEGER NOT NULL, + "worktree_id" INT8 NOT NULL, "path" VARCHAR NOT NULL, - "language_server_id" INTEGER NOT NULL, + "language_server_id" INT8 NOT NULL, "error_count" INTEGER NOT NULL, "warning_count" INTEGER NOT NULL, PRIMARY KEY(project_id, worktree_id, path), @@ -52,7 +52,7 @@ CREATE INDEX "index_worktree_diagnostic_summaries_on_project_id_and_worktree_id" CREATE TABLE "language_servers" ( "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, - "id" INTEGER NOT NULL, + "id" INT8 NOT NULL, "name" VARCHAR NOT NULL, PRIMARY KEY(project_id, id) ); 
diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 945ac1b577f31944b454597ee1713be8b37c18d9..7395a7cc769e7f72f053f5dcd2a0f2792b565011 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1494,17 +1494,21 @@ impl Database { .insert(&tx) .await?; - worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i64), - project_id: ActiveValue::set(project.id), - abs_path: ActiveValue::set(worktree.abs_path.clone()), - root_name: ActiveValue::set(worktree.root_name.clone()), - visible: ActiveValue::set(worktree.visible), - scan_id: ActiveValue::set(0), - is_complete: ActiveValue::set(false), - })) - .exec(&tx) - .await?; + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| { + worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + } + })) + .exec(&tx) + .await?; + } project_collaborator::ActiveModel { project_id: ActiveValue::set(project.id), @@ -1564,17 +1568,27 @@ impl Database { .await? 
.ok_or_else(|| anyhow!("no such project"))?; - worktree::Entity::insert_many(worktrees.iter().map(|worktree| worktree::ActiveModel { - id: ActiveValue::set(worktree.id as i64), - project_id: ActiveValue::set(project.id), - abs_path: ActiveValue::set(worktree.abs_path.clone()), - root_name: ActiveValue::set(worktree.root_name.clone()), - visible: ActiveValue::set(worktree.visible), - scan_id: ActiveValue::set(0), - is_complete: ActiveValue::set(false), - })) - .exec(&tx) - .await?; + if !worktrees.is_empty() { + worktree::Entity::insert_many(worktrees.iter().map(|worktree| { + worktree::ActiveModel { + id: ActiveValue::set(worktree.id as i64), + project_id: ActiveValue::set(project.id), + abs_path: ActiveValue::set(worktree.abs_path.clone()), + root_name: ActiveValue::set(worktree.root_name.clone()), + visible: ActiveValue::set(worktree.visible), + scan_id: ActiveValue::set(0), + is_complete: ActiveValue::set(false), + } + })) + .on_conflict( + OnConflict::columns([worktree::Column::ProjectId, worktree::Column::Id]) + .update_column(worktree::Column::RootName) + .to_owned(), + ) + .exec(&tx) + .await?; + } + worktree::Entity::delete_many() .filter( worktree::Column::ProjectId.eq(project.id).and( @@ -1623,53 +1637,57 @@ impl Database { .exec(&tx) .await?; - worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| { - let mtime = entry.mtime.clone().unwrap_or_default(); - worktree_entry::ActiveModel { - project_id: ActiveValue::set(project_id), - worktree_id: ActiveValue::set(worktree_id), - id: ActiveValue::set(entry.id as i64), - is_dir: ActiveValue::set(entry.is_dir), - path: ActiveValue::set(entry.path.clone()), - inode: ActiveValue::set(entry.inode as i64), - mtime_seconds: ActiveValue::set(mtime.seconds as i64), - mtime_nanos: ActiveValue::set(mtime.nanos as i32), - is_symlink: ActiveValue::set(entry.is_symlink), - is_ignored: ActiveValue::set(entry.is_ignored), - } - })) - .on_conflict( - OnConflict::columns([ - 
worktree_entry::Column::ProjectId, - worktree_entry::Column::WorktreeId, - worktree_entry::Column::Id, - ]) - .update_columns([ - worktree_entry::Column::IsDir, - worktree_entry::Column::Path, - worktree_entry::Column::Inode, - worktree_entry::Column::MtimeSeconds, - worktree_entry::Column::MtimeNanos, - worktree_entry::Column::IsSymlink, - worktree_entry::Column::IsIgnored, - ]) - .to_owned(), - ) - .exec(&tx) - .await?; - - worktree_entry::Entity::delete_many() - .filter( - worktree_entry::Column::ProjectId - .eq(project_id) - .and(worktree_entry::Column::WorktreeId.eq(worktree_id)) - .and( - worktree_entry::Column::Id - .is_in(update.removed_entries.iter().map(|id| *id as i64)), - ), + if !update.updated_entries.is_empty() { + worktree_entry::Entity::insert_many(update.updated_entries.iter().map(|entry| { + let mtime = entry.mtime.clone().unwrap_or_default(); + worktree_entry::ActiveModel { + project_id: ActiveValue::set(project_id), + worktree_id: ActiveValue::set(worktree_id), + id: ActiveValue::set(entry.id as i64), + is_dir: ActiveValue::set(entry.is_dir), + path: ActiveValue::set(entry.path.clone()), + inode: ActiveValue::set(entry.inode as i64), + mtime_seconds: ActiveValue::set(mtime.seconds as i64), + mtime_nanos: ActiveValue::set(mtime.nanos as i32), + is_symlink: ActiveValue::set(entry.is_symlink), + is_ignored: ActiveValue::set(entry.is_ignored), + } + })) + .on_conflict( + OnConflict::columns([ + worktree_entry::Column::ProjectId, + worktree_entry::Column::WorktreeId, + worktree_entry::Column::Id, + ]) + .update_columns([ + worktree_entry::Column::IsDir, + worktree_entry::Column::Path, + worktree_entry::Column::Inode, + worktree_entry::Column::MtimeSeconds, + worktree_entry::Column::MtimeNanos, + worktree_entry::Column::IsSymlink, + worktree_entry::Column::IsIgnored, + ]) + .to_owned(), ) .exec(&tx) .await?; + } + + if !update.removed_entries.is_empty() { + worktree_entry::Entity::delete_many() + .filter( + worktree_entry::Column::ProjectId + 
.eq(project_id) + .and(worktree_entry::Column::WorktreeId.eq(worktree_id)) + .and( + worktree_entry::Column::Id + .is_in(update.removed_entries.iter().map(|id| *id as i64)), + ), + ) + .exec(&tx) + .await?; + } let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; self.commit_room_transaction(room_id, tx, connection_ids) From 1b46b7a7d6d14e24646ba1db46069ff6b63c9942 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 14:37:26 +0100 Subject: [PATCH 095/109] Move modules into `collab` library as opposed to using the binary This ensures that we can use collab's modules from the seed script as well. --- crates/collab/src/bin/seed.rs | 10 ++--- crates/collab/src/lib.rs | 65 ++++++++++++++++++++++++++++++ crates/collab/src/main.rs | 75 +++-------------------------------- 3 files changed, 73 insertions(+), 77 deletions(-) diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index 3b635540b315bfbebe6058f9457e65237a0f1e3b..2f7c61147cbc84ddbeafb608452cb8b1daf2138e 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -1,12 +1,8 @@ -use collab::{Error, Result}; -use db::{DefaultDb, UserId}; +use collab::{db, Error, Result}; +use db::{ConnectOptions, Database, UserId}; use serde::{de::DeserializeOwned, Deserialize}; use std::fmt::Write; -#[allow(unused)] -#[path = "../db.rs"] -mod db; - #[derive(Debug, Deserialize)] struct GitHubUser { id: i32, @@ -17,7 +13,7 @@ struct GitHubUser { #[tokio::main] async fn main() { let database_url = std::env::var("DATABASE_URL").expect("missing DATABASE_URL env var"); - let db = DefaultDb::new(&database_url, 5) + let db = Database::new(ConnectOptions::new(database_url)) .await .expect("failed to connect to postgres database"); let github_token = std::env::var("GITHUB_TOKEN").expect("missing GITHUB_TOKEN env var"); diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 
23af3344b55656781ea735d81287213186508c94..9011d2a1ebb7a88907037d2809b7289bf96e051a 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -1,4 +1,15 @@ +pub mod api; +pub mod auth; +pub mod db; +pub mod env; +#[cfg(test)] +mod integration_tests; +pub mod rpc; + use axum::{http::StatusCode, response::IntoResponse}; +use db::Database; +use serde::Deserialize; +use std::{path::PathBuf, sync::Arc}; pub type Result = std::result::Result; @@ -85,3 +96,57 @@ impl std::fmt::Display for Error { } impl std::error::Error for Error {} + +#[derive(Default, Deserialize)] +pub struct Config { + pub http_port: u16, + pub database_url: String, + pub api_token: String, + pub invite_link_prefix: String, + pub live_kit_server: Option, + pub live_kit_key: Option, + pub live_kit_secret: Option, + pub rust_log: Option, + pub log_json: Option, +} + +#[derive(Default, Deserialize)] +pub struct MigrateConfig { + pub database_url: String, + pub migrations_path: Option, +} + +pub struct AppState { + pub db: Arc, + pub live_kit_client: Option>, + pub config: Config, +} + +impl AppState { + pub async fn new(config: Config) -> Result> { + let mut db_options = db::ConnectOptions::new(config.database_url.clone()); + db_options.max_connections(5); + let db = Database::new(db_options).await?; + let live_kit_client = if let Some(((server, key), secret)) = config + .live_kit_server + .as_ref() + .zip(config.live_kit_key.as_ref()) + .zip(config.live_kit_secret.as_ref()) + { + Some(Arc::new(live_kit_server::api::LiveKitClient::new( + server.clone(), + key.clone(), + secret.clone(), + )) as Arc) + } else { + None + }; + + let this = Self { + db: Arc::new(db), + live_kit_client, + config, + }; + Ok(Arc::new(this)) + } +} diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 4802fd82b41f5f0a069da7168a683cc7ab46e641..42ffe50ea3da084bcfbc91cb7da5fb71505283e9 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -1,22 +1,11 @@ -mod api; -mod auth; -mod 
db; -mod env; -mod rpc; - -#[cfg(test)] -mod integration_tests; - use anyhow::anyhow; use axum::{routing::get, Router}; -use collab::{Error, Result}; +use collab::{db, env, AppState, Config, MigrateConfig, Result}; use db::Database; -use serde::Deserialize; use std::{ env::args, net::{SocketAddr, TcpListener}, - path::{Path, PathBuf}, - sync::Arc, + path::Path, }; use tracing_log::LogTracer; use tracing_subscriber::{filter::EnvFilter, fmt::format::JsonFields, Layer}; @@ -24,60 +13,6 @@ use util::ResultExt; const VERSION: &'static str = env!("CARGO_PKG_VERSION"); -#[derive(Default, Deserialize)] -pub struct Config { - pub http_port: u16, - pub database_url: String, - pub api_token: String, - pub invite_link_prefix: String, - pub live_kit_server: Option, - pub live_kit_key: Option, - pub live_kit_secret: Option, - pub rust_log: Option, - pub log_json: Option, -} - -#[derive(Default, Deserialize)] -pub struct MigrateConfig { - pub database_url: String, - pub migrations_path: Option, -} - -pub struct AppState { - db: Arc, - live_kit_client: Option>, - config: Config, -} - -impl AppState { - async fn new(config: Config) -> Result> { - let mut db_options = db::ConnectOptions::new(config.database_url.clone()); - db_options.max_connections(5); - let db = Database::new(db_options).await?; - let live_kit_client = if let Some(((server, key), secret)) = config - .live_kit_server - .as_ref() - .zip(config.live_kit_key.as_ref()) - .zip(config.live_kit_secret.as_ref()) - { - Some(Arc::new(live_kit_server::api::LiveKitClient::new( - server.clone(), - key.clone(), - secret.clone(), - )) as Arc) - } else { - None - }; - - let this = Self { - db: Arc::new(db), - live_kit_client, - config, - }; - Ok(Arc::new(this)) - } -} - #[tokio::main] async fn main() -> Result<()> { if let Err(error) = env::load_dotenv() { @@ -120,10 +55,10 @@ async fn main() -> Result<()> { let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); - 
let rpc_server = rpc::Server::new(state.clone()); + let rpc_server = collab::rpc::Server::new(state.clone()); - let app = api::routes(rpc_server.clone(), state.clone()) - .merge(rpc::routes(rpc_server.clone())) + let app = collab::api::routes(rpc_server.clone(), state.clone()) + .merge(collab::rpc::routes(rpc_server.clone())) .merge(Router::new().route("/", get(handle_root))); axum::Server::from_tcp(listener)? From 27f6ae945d2c53fe367c87672162913a5aef3baa Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 16:30:00 +0100 Subject: [PATCH 096/109] Clear stale data on startup This is a stopgap measure until we introduce reconnection support. --- .../20221109000000_test_schema.sql | 12 +++++-- .../20221111092550_reconnection_support.sql | 10 +++++- crates/collab/src/bin/seed.rs | 2 +- crates/collab/src/db.rs | 33 +++++++++++++++++++ crates/collab/src/db/project.rs | 1 + crates/collab/src/db/project_collaborator.rs | 1 + crates/collab/src/db/room_participant.rs | 2 ++ crates/collab/src/main.rs | 2 ++ 8 files changed, 59 insertions(+), 4 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index e62f834fbf07ca4d4265e1f7d710323413193e64..347db6a71a8d44f21f5cfcac7c3c73a1c67856c9 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -43,8 +43,10 @@ CREATE TABLE "projects" ( "id" INTEGER PRIMARY KEY, "room_id" INTEGER REFERENCES rooms (id) NOT NULL, "host_user_id" INTEGER REFERENCES users (id) NOT NULL, - "host_connection_id" INTEGER NOT NULL + "host_connection_id" INTEGER NOT NULL, + "host_connection_epoch" TEXT NOT NULL ); +CREATE INDEX "index_projects_on_host_connection_epoch" ON "projects" ("host_connection_epoch"); CREATE TABLE "worktrees" ( "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, @@ -100,22 +102,28 @@ CREATE TABLE 
"project_collaborators" ( "id" INTEGER PRIMARY KEY, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, + "connection_epoch" TEXT NOT NULL, "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL, "is_host" BOOLEAN NOT NULL ); CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE INDEX "index_project_collaborators_on_connection_epoch" ON "project_collaborators" ("connection_epoch"); CREATE TABLE "room_participants" ( "id" INTEGER PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), "user_id" INTEGER NOT NULL REFERENCES users (id), "answering_connection_id" INTEGER, + "answering_connection_epoch" TEXT, "location_kind" INTEGER, "location_project_id" INTEGER REFERENCES projects (id), "initial_project_id" INTEGER REFERENCES projects (id), "calling_user_id" INTEGER NOT NULL REFERENCES users (id), - "calling_connection_id" INTEGER NOT NULL + "calling_connection_id" INTEGER NOT NULL, + "calling_connection_epoch" TEXT NOT NULL ); CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); +CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_participants" ("answering_connection_epoch"); +CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch"); diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index d23dbfa046942c22802a05a2bbe86ae600044f61..6278fa7a595b05cb7adbf97f622b06f675116af3 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -6,7 +6,9 @@ CREATE TABLE IF NOT EXISTS "rooms" ( ALTER TABLE "projects" ADD 
"room_id" INTEGER REFERENCES rooms (id), ADD "host_connection_id" INTEGER, + ADD "host_connection_epoch" UUID, DROP COLUMN "unregistered"; +CREATE INDEX "index_projects_on_host_connection_epoch" ON "projects" ("host_connection_epoch"); CREATE TABLE "worktrees" ( "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, @@ -62,22 +64,28 @@ CREATE TABLE "project_collaborators" ( "id" SERIAL PRIMARY KEY, "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE, "connection_id" INTEGER NOT NULL, + "connection_epoch" UUID NOT NULL, "user_id" INTEGER NOT NULL, "replica_id" INTEGER NOT NULL, "is_host" BOOLEAN NOT NULL ); CREATE INDEX "index_project_collaborators_on_project_id" ON "project_collaborators" ("project_id"); CREATE UNIQUE INDEX "index_project_collaborators_on_project_id_and_replica_id" ON "project_collaborators" ("project_id", "replica_id"); +CREATE INDEX "index_project_collaborators_on_connection_epoch" ON "project_collaborators" ("connection_epoch"); CREATE TABLE "room_participants" ( "id" SERIAL PRIMARY KEY, "room_id" INTEGER NOT NULL REFERENCES rooms (id), "user_id" INTEGER NOT NULL REFERENCES users (id), "answering_connection_id" INTEGER, + "answering_connection_epoch" UUID, "location_kind" INTEGER, "location_project_id" INTEGER REFERENCES projects (id), "initial_project_id" INTEGER REFERENCES projects (id), "calling_user_id" INTEGER NOT NULL REFERENCES users (id), - "calling_connection_id" INTEGER NOT NULL + "calling_connection_id" INTEGER NOT NULL, + "calling_connection_epoch" UUID NOT NULL ); CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id"); +CREATE INDEX "index_room_participants_on_answering_connection_epoch" ON "room_participants" ("answering_connection_epoch"); +CREATE INDEX "index_room_participants_on_calling_connection_epoch" ON "room_participants" ("calling_connection_epoch"); diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index 
2f7c61147cbc84ddbeafb608452cb8b1daf2138e..9860b8be845a360ba4477b7fb48527e4050f491f 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -1,4 +1,4 @@ -use collab::{db, Error, Result}; +use collab::db; use db::{ConnectOptions, Database, UserId}; use serde::{de::DeserializeOwned, Deserialize}; use std::fmt::Write; diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 7395a7cc769e7f72f053f5dcd2a0f2792b565011..05d62741089ccc2c8685706584cb6c0f4fcdfa63 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -47,6 +47,7 @@ pub struct Database { background: Option>, #[cfg(test)] runtime: Option, + epoch: Uuid, } impl Database { @@ -59,6 +60,7 @@ impl Database { background: None, #[cfg(test)] runtime: None, + epoch: Uuid::new_v4(), }) } @@ -103,6 +105,30 @@ impl Database { Ok(new_migrations) } + pub async fn clear_stale_data(&self) -> Result<()> { + self.transact(|tx| async { + project_collaborator::Entity::delete_many() + .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch)) + .exec(&tx) + .await?; + room_participant::Entity::delete_many() + .filter( + room_participant::Column::AnsweringConnectionEpoch + .ne(self.epoch) + .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)), + ) + .exec(&tx) + .await?; + project::Entity::delete_many() + .filter(project::Column::HostConnectionEpoch.ne(self.epoch)) + .exec(&tx) + .await?; + tx.commit().await?; + Ok(()) + }) + .await + } + // users pub async fn create_user( @@ -983,8 +1009,10 @@ impl Database { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(user_id), answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + answering_connection_epoch: ActiveValue::set(Some(self.epoch)), calling_user_id: ActiveValue::set(user_id), calling_connection_id: ActiveValue::set(connection_id.0 as i32), + calling_connection_epoch: ActiveValue::set(self.epoch), ..Default::default() } .insert(&tx) @@ -1010,6 +1038,7 @@ impl Database { 
user_id: ActiveValue::set(called_user_id), calling_user_id: ActiveValue::set(calling_user_id), calling_connection_id: ActiveValue::set(calling_connection_id.0 as i32), + calling_connection_epoch: ActiveValue::set(self.epoch), initial_project_id: ActiveValue::set(initial_project_id), ..Default::default() } @@ -1127,6 +1156,7 @@ impl Database { ) .set(room_participant::ActiveModel { answering_connection_id: ActiveValue::set(Some(connection_id.0 as i32)), + answering_connection_epoch: ActiveValue::set(Some(self.epoch)), ..Default::default() }) .exec(&tx) @@ -1489,6 +1519,7 @@ impl Database { room_id: ActiveValue::set(participant.room_id), host_user_id: ActiveValue::set(participant.user_id), host_connection_id: ActiveValue::set(connection_id.0 as i32), + host_connection_epoch: ActiveValue::set(self.epoch), ..Default::default() } .insert(&tx) @@ -1513,6 +1544,7 @@ impl Database { project_collaborator::ActiveModel { project_id: ActiveValue::set(project.id), connection_id: ActiveValue::set(connection_id.0 as i32), + connection_epoch: ActiveValue::set(self.epoch), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(ReplicaId(0)), is_host: ActiveValue::set(true), @@ -1832,6 +1864,7 @@ impl Database { let new_collaborator = project_collaborator::ActiveModel { project_id: ActiveValue::set(project_id), connection_id: ActiveValue::set(connection_id.0 as i32), + connection_epoch: ActiveValue::set(self.epoch), user_id: ActiveValue::set(participant.user_id), replica_id: ActiveValue::set(replica_id), is_host: ActiveValue::set(false), diff --git a/crates/collab/src/db/project.rs b/crates/collab/src/db/project.rs index b109ddc4b8a6d9878eafefb4a4268bad4bc1975f..971a8fcefb465114c9703003e2a74f6f38d8c397 100644 --- a/crates/collab/src/db/project.rs +++ b/crates/collab/src/db/project.rs @@ -9,6 +9,7 @@ pub struct Model { pub room_id: RoomId, pub host_user_id: UserId, pub host_connection_id: i32, + pub host_connection_epoch: Uuid, } #[derive(Copy, Clone, Debug, 
EnumIter, DeriveRelation)] diff --git a/crates/collab/src/db/project_collaborator.rs b/crates/collab/src/db/project_collaborator.rs index 097272fcdafcff3bfa85bccca4abfa9570b5e508..5db307f5df27ec07f282207eb253ddae95c44970 100644 --- a/crates/collab/src/db/project_collaborator.rs +++ b/crates/collab/src/db/project_collaborator.rs @@ -8,6 +8,7 @@ pub struct Model { pub id: ProjectCollaboratorId, pub project_id: ProjectId, pub connection_id: i32, + pub connection_epoch: Uuid, pub user_id: UserId, pub replica_id: ReplicaId, pub is_host: bool, diff --git a/crates/collab/src/db/room_participant.rs b/crates/collab/src/db/room_participant.rs index c7c804581b07be6825bbc27b44227d8da4a6b26a..783f45aa93e1952be3f5dd2f5efd0d51da6665cd 100644 --- a/crates/collab/src/db/room_participant.rs +++ b/crates/collab/src/db/room_participant.rs @@ -9,11 +9,13 @@ pub struct Model { pub room_id: RoomId, pub user_id: UserId, pub answering_connection_id: Option, + pub answering_connection_epoch: Option, pub location_kind: Option, pub location_project_id: Option, pub initial_project_id: Option, pub calling_user_id: UserId, pub calling_connection_id: i32, + pub calling_connection_epoch: Uuid, } #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] diff --git a/crates/collab/src/main.rs b/crates/collab/src/main.rs index 42ffe50ea3da084bcfbc91cb7da5fb71505283e9..a288e0f3ce83fe8c7a0656f108f15c6088021d68 100644 --- a/crates/collab/src/main.rs +++ b/crates/collab/src/main.rs @@ -52,6 +52,8 @@ async fn main() -> Result<()> { init_tracing(&config); let state = AppState::new(config).await?; + state.db.clear_stale_data().await?; + let listener = TcpListener::bind(&format!("0.0.0.0:{}", state.config.http_port)) .expect("failed to bind TCP listener"); From 568de814aad478ccca1e792e83cb24ca7fea3172 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 16:52:48 +0100 Subject: [PATCH 097/109] Delete empty rooms --- crates/collab/src/db.rs | 29 +++++++++++++++++++++++++---- 1 file changed, 25 
insertions(+), 4 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 05d62741089ccc2c8685706584cb6c0f4fcdfa63..ea9757a973fa2e19d1130c9b8af3f54ede2ad504 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -123,6 +123,18 @@ impl Database { .filter(project::Column::HostConnectionEpoch.ne(self.epoch)) .exec(&tx) .await?; + room::Entity::delete_many() + .filter( + room::Column::Id.not_in_subquery( + Query::select() + .column(room_participant::Column::RoomId) + .from(room_participant::Entity) + .distinct() + .to_owned(), + ), + ) + .exec(&tx) + .await?; tx.commit().await?; Ok(()) }) @@ -1272,8 +1284,12 @@ impl Database { .await?; let room = self.get_room(room_id, &tx).await?; - Ok(Some( - self.commit_room_transaction( + if room.participants.is_empty() { + room::Entity::delete_by_id(room_id).exec(&tx).await?; + } + + let left_room = self + .commit_room_transaction( room_id, tx, LeftRoom { @@ -1282,8 +1298,13 @@ impl Database { canceled_calls_to_user_ids, }, ) - .await?, - )) + .await?; + + if left_room.room.participants.is_empty() { + self.rooms.remove(&room_id); + } + + Ok(Some(left_room)) } else { Ok(None) } From 1c30767592b2f204c70189f0a80580f7cbee8016 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 19:20:51 +0100 Subject: [PATCH 098/109] Remove stale `Error` variant Co-Authored-By: Max Brunsfeld --- crates/collab/src/db.rs | 2 +- crates/collab/src/lib.rs | 16 ++-------------- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 2a8163b9c82d5636ceec27b92284c38e72c8f277..fd1ed7d50ffb196149776f5232faf03e743d7137 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2203,7 +2203,7 @@ impl Database { match f(tx).await { Ok(result) => return Ok(result), Err(error) => match error { - Error::Database2( + Error::Database( DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), 
) if error diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 9011d2a1ebb7a88907037d2809b7289bf96e051a..24a9fc6117ce81ea493b742c2c6f7cbd6e8ca5d4 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -15,8 +15,7 @@ pub type Result = std::result::Result; pub enum Error { Http(StatusCode, String), - Database(sqlx::Error), - Database2(sea_orm::error::DbErr), + Database(sea_orm::error::DbErr), Internal(anyhow::Error), } @@ -26,15 +25,9 @@ impl From for Error { } } -impl From for Error { - fn from(error: sqlx::Error) -> Self { - Self::Database(error) - } -} - impl From for Error { fn from(error: sea_orm::error::DbErr) -> Self { - Self::Database2(error) + Self::Database(error) } } @@ -63,9 +56,6 @@ impl IntoResponse for Error { Error::Database(error) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } - Error::Database2(error) => { - (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() - } Error::Internal(error) => { (StatusCode::INTERNAL_SERVER_ERROR, format!("{}", &error)).into_response() } @@ -78,7 +68,6 @@ impl std::fmt::Debug for Error { match self { Error::Http(code, message) => (code, message).fmt(f), Error::Database(error) => error.fmt(f), - Error::Database2(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } @@ -89,7 +78,6 @@ impl std::fmt::Display for Error { match self { Error::Http(code, message) => write!(f, "{code}: {message}"), Error::Database(error) => error.fmt(f), - Error::Database2(error) => error.fmt(f), Error::Internal(error) => error.fmt(f), } } From d96f524fb6bb8873d8baaf96fcca7f690372fc53 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Fri, 2 Dec 2022 20:36:50 +0100 Subject: [PATCH 099/109] WIP: Manually rollback transactions to avoid spurious savepoint failure TODO: - Avoid unwrapping transaction after f(tx) - Remove duplication between `transaction` and `room_transaction` - Introduce random delay before and after committing a 
transaction - Run lots of randomized tests - Investigate diverging diagnostic summaries Co-Authored-By: Max Brunsfeld --- crates/collab/src/db.rs | 565 +++++++++++++++++++++------------------ crates/collab/src/rpc.rs | 4 +- 2 files changed, 300 insertions(+), 269 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index fd1ed7d50ffb196149776f5232faf03e743d7137..e667930cad2953d1379f9aa07389202f16ff2219 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -106,10 +106,10 @@ impl Database { } pub async fn clear_stale_data(&self) -> Result<()> { - self.transact(|tx| async { + self.transaction(|tx| async move { project_collaborator::Entity::delete_many() .filter(project_collaborator::Column::ConnectionEpoch.ne(self.epoch)) - .exec(&tx) + .exec(&*tx) .await?; room_participant::Entity::delete_many() .filter( @@ -117,11 +117,11 @@ impl Database { .ne(self.epoch) .or(room_participant::Column::CallingConnectionEpoch.ne(self.epoch)), ) - .exec(&tx) + .exec(&*tx) .await?; project::Entity::delete_many() .filter(project::Column::HostConnectionEpoch.ne(self.epoch)) - .exec(&tx) + .exec(&*tx) .await?; room::Entity::delete_many() .filter( @@ -133,9 +133,8 @@ impl Database { .to_owned(), ), ) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await @@ -149,7 +148,8 @@ impl Database { admin: bool, params: NewUserParams, ) -> Result { - self.transact(|tx| async { + self.transaction(|tx| async { + let tx = tx; let user = user::Entity::insert(user::ActiveModel { email_address: ActiveValue::set(Some(email_address.into())), github_login: ActiveValue::set(params.github_login.clone()), @@ -163,11 +163,9 @@ impl Database { .update_column(user::Column::GithubLogin) .to_owned(), ) - .exec_with_returning(&tx) + .exec_with_returning(&*tx) .await?; - tx.commit().await?; - Ok(NewUserResult { user_id: user.id, metrics_id: user.metrics_id.to_string(), @@ -179,16 +177,16 @@ impl Database { } pub async fn get_user_by_id(&self, id: UserId) -> 
Result> { - self.transact(|tx| async move { Ok(user::Entity::find_by_id(id).one(&tx).await?) }) + self.transaction(|tx| async move { Ok(user::Entity::find_by_id(id).one(&*tx).await?) }) .await } pub async fn get_users_by_ids(&self, ids: Vec) -> Result> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; Ok(user::Entity::find() .filter(user::Column::Id.is_in(ids.iter().copied())) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -199,32 +197,32 @@ impl Database { github_login: &str, github_user_id: Option, ) -> Result> { - self.transact(|tx| async { - let tx = tx; + self.transaction(|tx| async move { + let tx = &*tx; if let Some(github_user_id) = github_user_id { if let Some(user_by_github_user_id) = user::Entity::find() .filter(user::Column::GithubUserId.eq(github_user_id)) - .one(&tx) + .one(tx) .await? { let mut user_by_github_user_id = user_by_github_user_id.into_active_model(); user_by_github_user_id.github_login = ActiveValue::set(github_login.into()); - Ok(Some(user_by_github_user_id.update(&tx).await?)) + Ok(Some(user_by_github_user_id.update(tx).await?)) } else if let Some(user_by_github_login) = user::Entity::find() .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) + .one(tx) .await? { let mut user_by_github_login = user_by_github_login.into_active_model(); user_by_github_login.github_user_id = ActiveValue::set(Some(github_user_id)); - Ok(Some(user_by_github_login.update(&tx).await?)) + Ok(Some(user_by_github_login.update(tx).await?)) } else { Ok(None) } } else { Ok(user::Entity::find() .filter(user::Column::GithubLogin.eq(github_login)) - .one(&tx) + .one(tx) .await?) } }) @@ -232,12 +230,12 @@ impl Database { } pub async fn get_all_users(&self, page: u32, limit: u32) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(user::Entity::find() .order_by_asc(user::Column::GithubLogin) .limit(limit as u64) .offset(page as u64 * limit as u64) - .all(&tx) + .all(&*tx) .await?) 
}) .await @@ -247,7 +245,7 @@ impl Database { &self, invited_by_another_user: bool, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(user::Entity::find() .filter( user::Column::InviteCount @@ -258,7 +256,7 @@ impl Database { user::Column::InviterId.is_null() }), ) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -270,12 +268,12 @@ impl Database { MetricsId, } - self.transact(|tx| async move { + self.transaction(|tx| async move { let metrics_id: Uuid = user::Entity::find_by_id(id) .select_only() .column(user::Column::MetricsId) .into_values::<_, QueryAs>() - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("could not find user"))?; Ok(metrics_id.to_string()) @@ -284,45 +282,42 @@ impl Database { } pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { user::Entity::update_many() .filter(user::Column::Id.eq(id)) .set(user::ActiveModel { admin: ActiveValue::set(is_admin), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn set_user_connected_once(&self, id: UserId, connected_once: bool) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { user::Entity::update_many() .filter(user::Column::Id.eq(id)) .set(user::ActiveModel { connected_once: ActiveValue::set(connected_once), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn destroy_user(&self, id: UserId) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { access_token::Entity::delete_many() .filter(access_token::Column::UserId.eq(id)) - .exec(&tx) + .exec(&*tx) .await?; - user::Entity::delete_by_id(id).exec(&tx).await?; - tx.commit().await?; + user::Entity::delete_by_id(id).exec(&*tx).await?; Ok(()) }) .await @@ -342,7 +337,7 @@ impl Database { user_b_busy: bool, } - self.transact(|tx| 
async move { + self.transaction(|tx| async move { let user_a_participant = Alias::new("user_a_participant"); let user_b_participant = Alias::new("user_b_participant"); let mut db_contacts = contact::Entity::find() @@ -372,7 +367,7 @@ impl Database { user_b_participant, ) .into_model::() - .stream(&tx) + .stream(&*tx) .await?; let mut contacts = Vec::new(); @@ -421,10 +416,10 @@ impl Database { } pub async fn is_user_busy(&self, user_id: UserId) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::UserId.eq(user_id)) - .one(&tx) + .one(&*tx) .await?; Ok(participant.is_some()) }) @@ -432,7 +427,7 @@ impl Database { } pub async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b) = if user_id_1 < user_id_2 { (user_id_1, user_id_2) } else { @@ -446,7 +441,7 @@ impl Database { .and(contact::Column::UserIdB.eq(id_b)) .and(contact::Column::Accepted.eq(true)), ) - .one(&tx) + .one(&*tx) .await? .is_some()) }) @@ -454,7 +449,7 @@ impl Database { } pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) } else { @@ -487,11 +482,10 @@ impl Database { ) .to_owned(), ) - .exec_without_returning(&tx) + .exec_without_returning(&*tx) .await?; if rows_affected == 1 { - tx.commit().await?; Ok(()) } else { Err(anyhow!("contact already requested"))? 
@@ -501,7 +495,7 @@ impl Database { } pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) } else { @@ -514,11 +508,10 @@ impl Database { .eq(id_a) .and(contact::Column::UserIdB.eq(id_b)), ) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 1 { - tx.commit().await?; Ok(()) } else { Err(anyhow!("no such contact"))? @@ -532,7 +525,7 @@ impl Database { user_id: UserId, contact_user_id: UserId, ) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if user_id < contact_user_id { (user_id, contact_user_id, true) } else { @@ -557,12 +550,11 @@ impl Database { .and(contact::Column::Accepted.eq(false))), ), ) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("no such contact request"))? } else { - tx.commit().await?; Ok(()) } }) @@ -575,7 +567,7 @@ impl Database { requester_id: UserId, accept: bool, ) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) } else { @@ -594,7 +586,7 @@ impl Database { .and(contact::Column::UserIdB.eq(id_b)) .and(contact::Column::AToB.eq(a_to_b)), ) - .exec(&tx) + .exec(&*tx) .await?; result.rows_affected } else { @@ -606,14 +598,13 @@ impl Database { .and(contact::Column::AToB.eq(a_to_b)) .and(contact::Column::Accepted.eq(false)), ) - .exec(&tx) + .exec(&*tx) .await?; result.rows_affected }; if rows_affected == 1 { - tx.commit().await?; Ok(()) } else { Err(anyhow!("no such contact request"))? 
@@ -635,7 +626,7 @@ impl Database { } pub async fn fuzzy_search_users(&self, name_query: &str, limit: u32) -> Result> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; let like_string = Self::fuzzy_like_string(name_query); let query = " @@ -652,7 +643,7 @@ impl Database { query.into(), vec![like_string.into(), name_query.into(), limit.into()], )) - .all(&tx) + .all(&*tx) .await?) }) .await @@ -661,7 +652,7 @@ impl Database { // signups pub async fn create_signup(&self, signup: &NewSignup) -> Result<()> { - self.transact(|tx| async { + self.transaction(|tx| async move { signup::Entity::insert(signup::ActiveModel { email_address: ActiveValue::set(signup.email_address.clone()), email_confirmation_code: ActiveValue::set(random_email_confirmation_code()), @@ -681,16 +672,15 @@ impl Database { .update_column(signup::Column::EmailAddress) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn get_waitlist_summary(&self) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let query = " SELECT COUNT(*) as count, @@ -711,7 +701,7 @@ impl Database { query.into(), vec![], )) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("invalid result"))?, ) @@ -724,23 +714,23 @@ impl Database { .iter() .map(|s| s.email_address.as_str()) .collect::>(); - self.transact(|tx| async { + self.transaction(|tx| async { + let tx = tx; signup::Entity::update_many() .filter(signup::Column::EmailAddress.is_in(emails.iter().copied())) .set(signup::ActiveModel { email_confirmation_sent: ActiveValue::set(true), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn get_unsent_invites(&self, count: usize) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(signup::Entity::find() .select_only() .column(signup::Column::EmailAddress) @@ -755,7 +745,7 @@ impl Database { .order_by_asc(signup::Column::CreatedAt) .limit(count as u64) .into_model() - .all(&tx) + .all(&*tx) .await?) }) .await @@ -769,10 +759,10 @@ impl Database { email_address: &str, device_id: Option<&str>, ) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { let existing_user = user::Entity::find() .filter(user::Column::EmailAddress.eq(email_address)) - .one(&tx) + .one(&*tx) .await?; if existing_user.is_some() { @@ -785,7 +775,7 @@ impl Database { .eq(code) .and(user::Column::InviteCount.gt(0)), ) - .one(&tx) + .one(&*tx) .await? 
{ Some(inviting_user) => inviting_user, @@ -806,7 +796,7 @@ impl Database { user::Column::InviteCount, Expr::col(user::Column::InviteCount).sub(1), ) - .exec(&tx) + .exec(&*tx) .await?; let signup = signup::Entity::insert(signup::ActiveModel { @@ -826,9 +816,8 @@ impl Database { .update_column(signup::Column::InvitingUserId) .to_owned(), ) - .exec_with_returning(&tx) + .exec_with_returning(&*tx) .await?; - tx.commit().await?; Ok(Invite { email_address: signup.email_address, @@ -843,7 +832,7 @@ impl Database { invite: &Invite, user: NewUserParams, ) -> Result> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; let signup = signup::Entity::find() .filter( @@ -854,7 +843,7 @@ impl Database { .eq(invite.email_confirmation_code.as_str()), ), ) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| Error::Http(StatusCode::NOT_FOUND, "no such invite".to_string()))?; @@ -881,12 +870,12 @@ impl Database { ]) .to_owned(), ) - .exec_with_returning(&tx) + .exec_with_returning(&*tx) .await?; let mut signup = signup.into_active_model(); signup.user_id = ActiveValue::set(Some(user.id)); - let signup = signup.update(&tx).await?; + let signup = signup.update(&*tx).await?; if let Some(inviting_user_id) = signup.inviting_user_id { contact::Entity::insert(contact::ActiveModel { @@ -898,11 +887,10 @@ impl Database { ..Default::default() }) .on_conflict(OnConflict::new().do_nothing().to_owned()) - .exec_without_returning(&tx) + .exec_without_returning(&*tx) .await?; } - tx.commit().await?; Ok(Some(NewUserResult { user_id: user.id, metrics_id: user.metrics_id.to_string(), @@ -914,7 +902,7 @@ impl Database { } pub async fn set_invite_count_for_user(&self, id: UserId, count: i32) -> Result<()> { - self.transact(|tx| async move { + self.transaction(|tx| async move { if count > 0 { user::Entity::update_many() .filter( @@ -926,7 +914,7 @@ impl Database { invite_code: ActiveValue::set(Some(random_invite_code())), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; } 
@@ -936,17 +924,16 @@ impl Database { invite_count: ActiveValue::set(count), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await } pub async fn get_invite_code_for_user(&self, id: UserId) -> Result> { - self.transact(|tx| async move { - match user::Entity::find_by_id(id).one(&tx).await? { + self.transaction(|tx| async move { + match user::Entity::find_by_id(id).one(&*tx).await? { Some(user) if user.invite_code.is_some() => { Ok(Some((user.invite_code.unwrap(), user.invite_count))) } @@ -957,10 +944,10 @@ impl Database { } pub async fn get_user_for_invite_code(&self, code: &str) -> Result { - self.transact(|tx| async move { + self.transaction(|tx| async move { user::Entity::find() .filter(user::Column::InviteCode.eq(code)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| { Error::Http( @@ -978,14 +965,14 @@ impl Database { &self, user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let pending_participant = room_participant::Entity::find() .filter( room_participant::Column::UserId .eq(user_id) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .one(&tx) + .one(&*tx) .await?; if let Some(pending_participant) = pending_participant { @@ -1004,12 +991,12 @@ impl Database { connection_id: ConnectionId, live_kit_room: &str, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let room = room::ActiveModel { live_kit_room: ActiveValue::set(live_kit_room.into()), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room_id = room.id; @@ -1023,11 +1010,11 @@ impl Database { calling_connection_epoch: ActiveValue::set(self.epoch), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1040,7 +1027,7 @@ impl Database { called_user_id: UserId, initial_project_id: Option, 
) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { room_participant::ActiveModel { room_id: ActiveValue::set(room_id), user_id: ActiveValue::set(called_user_id), @@ -1050,14 +1037,13 @@ impl Database { initial_project_id: ActiveValue::set(initial_project_id), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; let incoming_call = Self::build_incoming_call(&room, called_user_id) .ok_or_else(|| anyhow!("failed to build incoming call"))?; - self.commit_room_transaction(room_id, tx, (room, incoming_call)) - .await + Ok((room_id, (room, incoming_call))) }) .await } @@ -1067,17 +1053,17 @@ impl Database { room_id: RoomId, called_user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { room_participant::Entity::delete_many() .filter( room_participant::Column::RoomId .eq(room_id) .and(room_participant::Column::UserId.eq(called_user_id)), ) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1087,14 +1073,14 @@ impl Database { expected_room_id: Option, user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter( room_participant::Column::UserId .eq(user_id) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("could not decline call"))?; let room_id = participant.room_id; @@ -1104,11 +1090,11 @@ impl Database { } room_participant::Entity::delete(participant.into_active_model()) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1119,7 +1105,7 @@ impl Database { calling_connection_id: ConnectionId, called_user_id: UserId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter( room_participant::Column::UserId @@ -1130,7 +1116,7 @@ impl Database { ) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("could not cancel call"))?; let room_id = participant.room_id; @@ -1139,11 +1125,11 @@ impl Database { } room_participant::Entity::delete(participant.into_active_model()) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) }) .await } @@ -1154,7 +1140,7 @@ impl Database { user_id: UserId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let result = room_participant::Entity::update_many() .filter( room_participant::Column::RoomId @@ -1167,33 +1153,30 @@ impl Database { answering_connection_epoch: ActiveValue::set(Some(self.epoch)), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("room does not exist or was already joined"))? 
} else { let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, room).await + Ok((room_id, room)) } }) .await } - pub async fn leave_room( - &self, - connection_id: ConnectionId, - ) -> Result>> { - self.transact(|tx| async move { + pub async fn leave_room(&self, connection_id: ConnectionId) -> Result> { + self.room_transaction(|tx| async move { let leaving_participant = room_participant::Entity::find() .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await?; if let Some(leaving_participant) = leaving_participant { // Leave room. let room_id = leaving_participant.room_id; room_participant::Entity::delete_by_id(leaving_participant.id) - .exec(&tx) + .exec(&*tx) .await?; // Cancel pending calls initiated by the leaving user. @@ -1203,14 +1186,14 @@ impl Database { .eq(connection_id.0) .and(room_participant::Column::AnsweringConnectionId.is_null()), ) - .all(&tx) + .all(&*tx) .await?; room_participant::Entity::delete_many() .filter( room_participant::Column::Id .is_in(called_participants.iter().map(|participant| participant.id)), ) - .exec(&tx) + .exec(&*tx) .await?; let canceled_calls_to_user_ids = called_participants .into_iter() @@ -1230,12 +1213,12 @@ impl Database { ) .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) .into_values::<_, QueryProjectIds>() - .all(&tx) + .all(&*tx) .await?; let mut left_projects = HashMap::default(); let mut collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.is_in(project_ids)) - .stream(&tx) + .stream(&*tx) .await?; while let Some(collaborator) = collaborators.next().await { let collaborator = collaborator?; @@ -1266,7 +1249,7 @@ impl Database { // Leave projects. project_collaborator::Entity::delete_many() .filter(project_collaborator::Column::ConnectionId.eq(connection_id.0)) - .exec(&tx) + .exec(&*tx) .await?; // Unshare projects. 
@@ -1276,33 +1259,27 @@ impl Database { .eq(room_id) .and(project::Column::HostConnectionId.eq(connection_id.0)), ) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; if room.participants.is_empty() { - room::Entity::delete_by_id(room_id).exec(&tx).await?; + room::Entity::delete_by_id(room_id).exec(&*tx).await?; } - let left_room = self - .commit_room_transaction( - room_id, - tx, - LeftRoom { - room, - left_projects, - canceled_calls_to_user_ids, - }, - ) - .await?; + let left_room = LeftRoom { + room, + left_projects, + canceled_calls_to_user_ids, + }; if left_room.room.participants.is_empty() { self.rooms.remove(&room_id); } - Ok(Some(left_room)) + Ok((room_id, left_room)) } else { - Ok(None) + Err(anyhow!("could not leave room"))? } }) .await @@ -1314,8 +1291,8 @@ impl Database { connection_id: ConnectionId, location: proto::ParticipantLocation, ) -> Result> { - self.transact(|tx| async { - let mut tx = tx; + self.room_transaction(|tx| async { + let tx = tx; let location_kind; let location_project_id; match location @@ -1348,12 +1325,12 @@ impl Database { location_project_id: ActiveValue::set(location_project_id), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 1 { - let room = self.get_room(room_id, &mut tx).await?; - self.commit_room_transaction(room_id, tx, room).await + let room = self.get_room(room_id, &tx).await?; + Ok((room_id, room)) } else { Err(anyhow!("could not update room participant location"))? 
} @@ -1478,22 +1455,6 @@ impl Database { }) } - async fn commit_room_transaction( - &self, - room_id: RoomId, - tx: DatabaseTransaction, - data: T, - ) -> Result> { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }) - } - // projects pub async fn project_count_excluding_admins(&self) -> Result { @@ -1502,14 +1463,14 @@ impl Database { Count, } - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(project::Entity::find() .select_only() .column_as(project::Column::Id.count(), QueryAs::Count) .inner_join(user::Entity) .filter(user::Column::Admin.eq(false)) .into_values::<_, QueryAs>() - .one(&tx) + .one(&*tx) .await? .unwrap_or(0) as usize) }) @@ -1522,10 +1483,10 @@ impl Database { connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("could not find participant"))?; if participant.room_id != room_id { @@ -1539,7 +1500,7 @@ impl Database { host_connection_epoch: ActiveValue::set(self.epoch), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; if !worktrees.is_empty() { @@ -1554,7 +1515,7 @@ impl Database { is_complete: ActiveValue::set(false), } })) - .exec(&tx) + .exec(&*tx) .await?; } @@ -1567,12 +1528,11 @@ impl Database { is_host: ActiveValue::set(true), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (project.id, room)) - .await + Ok((room_id, (project.id, room))) }) .await } @@ -1582,21 +1542,20 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result)>> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let guest_connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("project not found"))?; if project.host_connection_id == connection_id.0 as i32 { let room_id = project.room_id; project::Entity::delete(project.into_active_model()) - .exec(&tx) + .exec(&*tx) .await?; let room = self.get_room(room_id, &tx).await?; - self.commit_room_transaction(room_id, tx, (room, guest_connection_ids)) - .await + Ok((room_id, (room, guest_connection_ids))) } else { Err(anyhow!("cannot unshare a project hosted by another user"))? } @@ -1610,10 +1569,10 @@ impl Database { connection_id: ConnectionId, worktrees: &[proto::WorktreeMetadata], ) -> Result)>> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let project = project::Entity::find_by_id(project_id) .filter(project::Column::HostConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("no such project"))?; @@ -1634,7 +1593,7 @@ impl Database { .update_column(worktree::Column::RootName) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; } @@ -1645,13 +1604,12 @@ impl Database { .is_not_in(worktrees.iter().map(|worktree| worktree.id as i64)), ), ) - .exec(&tx) + .exec(&*tx) .await?; let guest_connection_ids = self.project_guest_connection_ids(project.id, &tx).await?; let room = self.get_room(project.room_id, &tx).await?; - self.commit_room_transaction(project.room_id, tx, (room, guest_connection_ids)) - .await + Ok((project.room_id, (room, guest_connection_ids))) }) .await } @@ -1661,14 +1619,14 @@ impl Database { update: &proto::UpdateWorktree, connection_id: ConnectionId, ) -> Result>> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = update.worktree_id as i64; // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) .filter(project::Column::HostConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("no such project"))?; let room_id = project.room_id; @@ -1683,7 +1641,7 @@ impl Database { abs_path: ActiveValue::set(update.abs_path.clone()), ..Default::default() }) - .exec(&tx) + .exec(&*tx) .await?; if !update.updated_entries.is_empty() { @@ -1719,7 +1677,7 @@ impl Database { ]) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; } @@ -1734,13 +1692,12 @@ impl Database { .is_in(update.removed_entries.iter().map(|id| *id as i64)), ), ) - .exec(&tx) + .exec(&*tx) .await?; } let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - self.commit_room_transaction(room_id, tx, connection_ids) - .await + Ok((room_id, connection_ids)) }) .await } @@ -1750,7 +1707,7 @@ impl Database { update: &proto::UpdateDiagnosticSummary, connection_id: ConnectionId, ) -> Result>> { - self.transact(|tx| async { + self.room_transaction(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); let worktree_id = update.worktree_id as i64; let summary = update @@ -1760,7 +1717,7 @@ impl Database { // Ensure the update comes from the host. let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; if project.host_connection_id != connection_id.0 as i32 { @@ -1790,12 +1747,11 @@ impl Database { ]) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - self.commit_room_transaction(project.room_id, tx, connection_ids) - .await + Ok((project.room_id, connection_ids)) }) .await } @@ -1805,7 +1761,7 @@ impl Database { update: &proto::StartLanguageServer, connection_id: ConnectionId, ) -> Result>> { - self.transact(|tx| async { + self.room_transaction(|tx| async move { let project_id = ProjectId::from_proto(update.project_id); let server = update .server @@ -1814,7 +1770,7 @@ impl Database { // Ensure the update comes from the host. 
let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("no such project"))?; if project.host_connection_id != connection_id.0 as i32 { @@ -1836,12 +1792,11 @@ impl Database { .update_column(language_server::Column::Name) .to_owned(), ) - .exec(&tx) + .exec(&*tx) .await?; let connection_ids = self.project_guest_connection_ids(project_id, &tx).await?; - self.commit_room_transaction(project.room_id, tx, connection_ids) - .await + Ok((project.room_id, connection_ids)) }) .await } @@ -1851,15 +1806,15 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let participant = room_participant::Entity::find() .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.0)) - .one(&tx) + .one(&*tx) .await? .ok_or_else(|| anyhow!("must join a room first"))?; let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("no such project"))?; if project.room_id != participant.room_id { @@ -1868,7 +1823,7 @@ impl Database { let mut collaborators = project .find_related(project_collaborator::Entity) - .all(&tx) + .all(&*tx) .await?; let replica_ids = collaborators .iter() @@ -1887,11 +1842,11 @@ impl Database { is_host: ActiveValue::set(false), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; collaborators.push(new_collaborator); - let db_worktrees = project.find_related(worktree::Entity).all(&tx).await?; + let db_worktrees = project.find_related(worktree::Entity).all(&*tx).await?; let mut worktrees = db_worktrees .into_iter() .map(|db_worktree| { @@ -1915,7 +1870,7 @@ impl Database { { let mut db_entries = worktree_entry::Entity::find() .filter(worktree_entry::Column::ProjectId.eq(project_id)) - .stream(&tx) + .stream(&*tx) .await?; while let Some(db_entry) = db_entries.next().await { let db_entry = db_entry?; @@ -1940,7 +1895,7 @@ impl Database { { let mut db_summaries = worktree_diagnostic_summary::Entity::find() .filter(worktree_diagnostic_summary::Column::ProjectId.eq(project_id)) - .stream(&tx) + .stream(&*tx) .await?; while let Some(db_summary) = db_summaries.next().await { let db_summary = db_summary?; @@ -1960,28 +1915,22 @@ impl Database { // Populate language servers. 
let language_servers = project .find_related(language_server::Entity) - .all(&tx) + .all(&*tx) .await?; - self.commit_room_transaction( - project.room_id, - tx, - ( - Project { - collaborators, - worktrees, - language_servers: language_servers - .into_iter() - .map(|language_server| proto::LanguageServer { - id: language_server.id as u64, - name: language_server.name, - }) - .collect(), - }, - replica_id as ReplicaId, - ), - ) - .await + let room_id = project.room_id; + let project = Project { + collaborators, + worktrees, + language_servers: language_servers + .into_iter() + .map(|language_server| proto::LanguageServer { + id: language_server.id as u64, + name: language_server.name, + }) + .collect(), + }; + Ok((room_id, (project, replica_id as ReplicaId))) }) .await } @@ -1991,43 +1940,39 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.room_transaction(|tx| async move { let result = project_collaborator::Entity::delete_many() .filter( project_collaborator::Column::ProjectId .eq(project_id) .and(project_collaborator::Column::ConnectionId.eq(connection_id.0)), ) - .exec(&tx) + .exec(&*tx) .await?; if result.rows_affected == 0 { Err(anyhow!("not a collaborator on this project"))?; } let project = project::Entity::find_by_id(project_id) - .one(&tx) + .one(&*tx) .await? 
.ok_or_else(|| anyhow!("no such project"))?; let collaborators = project .find_related(project_collaborator::Entity) - .all(&tx) + .all(&*tx) .await?; let connection_ids = collaborators .into_iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)) .collect(); - self.commit_room_transaction( - project.room_id, - tx, - LeftProject { - id: project_id, - host_user_id: project.host_user_id, - host_connection_id: ConnectionId(project.host_connection_id as u32), - connection_ids, - }, - ) - .await + let left_project = LeftProject { + id: project_id, + host_user_id: project.host_user_id, + host_connection_id: ConnectionId(project.host_connection_id as u32), + connection_ids, + }; + Ok((project.room_id, left_project)) }) .await } @@ -2037,10 +1982,10 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { let collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.eq(project_id)) - .all(&tx) + .all(&*tx) .await?; if collaborators @@ -2060,7 +2005,7 @@ impl Database { project_id: ProjectId, connection_id: ConnectionId, ) -> Result> { - self.transact(|tx| async move { + self.transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryAs { ConnectionId, @@ -2074,7 +2019,7 @@ impl Database { ) .filter(project_collaborator::Column::ProjectId.eq(project_id)) .into_values::() - .stream(&tx) + .stream(&*tx) .await?; let mut connection_ids = HashSet::default(); @@ -2131,7 +2076,7 @@ impl Database { access_token_hash: &str, max_access_token_count: usize, ) -> Result<()> { - self.transact(|tx| async { + self.transaction(|tx| async { let tx = tx; access_token::ActiveModel { @@ -2139,7 +2084,7 @@ impl Database { hash: ActiveValue::set(access_token_hash.into()), ..Default::default() } - .insert(&tx) + .insert(&*tx) .await?; access_token::Entity::delete_many() @@ -2155,9 +2100,8 @@ impl 
Database { .to_owned(), ), ) - .exec(&tx) + .exec(&*tx) .await?; - tx.commit().await?; Ok(()) }) .await @@ -2169,22 +2113,22 @@ impl Database { Hash, } - self.transact(|tx| async move { + self.transaction(|tx| async move { Ok(access_token::Entity::find() .select_only() .column(access_token::Column::Hash) .filter(access_token::Column::UserId.eq(user_id)) .order_by_desc(access_token::Column::Id) .into_values::<_, QueryAs>() - .all(&tx) + .all(&*tx) .await?) }) .await } - async fn transact(&self, f: F) -> Result + async fn transaction(&self, f: F) -> Result where - F: Send + Fn(DatabaseTransaction) -> Fut, + F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { let body = async { @@ -2200,22 +2144,101 @@ impl Database { .await?; } - match f(tx).await { - Ok(result) => return Ok(result), - Err(error) => match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { - // Retry (don't break the loop) + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); + + match result { + Ok(result) => { + tx.commit().await?; + return Ok(result); + } + Err(error) => { + tx.rollback().await?; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => + { + // Retry (don't break the loop) + } + error @ _ => return Err(error), } - error @ _ => return Err(error), - }, + } + } + } + }; + + #[cfg(test)] + { + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + self.runtime.as_ref().unwrap().block_on(body) + } + + #[cfg(not(test))] + { + body.await + } + } + + 
async fn room_transaction(&self, f: F) -> Result> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let body = async { + loop { + let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); + + match result { + Ok((room_id, data)) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + tx.commit().await?; + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + tx.rollback().await?; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some("40001") => + { + // Retry (don't break the loop) + } + error @ _ => return Err(error), + } + } } } }; @@ -2236,6 +2259,16 @@ impl Database { } } +struct TransactionHandle(Arc>); + +impl Deref for TransactionHandle { + type Target = DatabaseTransaction; + + fn deref(&self) -> &Self::Target { + self.0.as_ref().as_ref().unwrap() + } +} + pub struct RoomGuard { data: T, _guard: OwnedMutexGuard<()>, diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 9d3917a417ef4bde4f20b09771ab11fbdc26acfd..7f404feffe04d02329b58be660e226cf4d3fe008 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1854,9 +1854,7 @@ async fn leave_room_for_session(session: &Session) -> Result<()> { let live_kit_room; let delete_live_kit_room; { - let Some(mut left_room) = session.db().await.leave_room(session.connection_id).await? 
else { - return Err(anyhow!("no room to leave"))?; - }; + let mut left_room = session.db().await.leave_room(session.connection_id).await?; contacts_to_update.insert(session.user_id); for project in left_room.left_projects.values() { From 0ed731780a113934f37f9ab0a5f428dd288692b0 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 09:46:03 +0100 Subject: [PATCH 100/109] Remove duplication between `transaction` and `room_transaction` --- crates/collab/src/db.rs | 57 +++++++++++++++++++---------------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index e667930cad2953d1379f9aa07389202f16ff2219..3066260bc431f65f18a68bbc7bd68442c18e0078 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2133,21 +2133,7 @@ impl Database { { let body = async { loop { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) - .await?; - } - - let mut tx = Arc::new(Some(tx)); - let result = f(TransactionHandle(tx.clone())).await; - let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); - + let (tx, result) = self.with_transaction(&f).await?; match result { Ok(result) => { tx.commit().await?; @@ -2196,21 +2182,7 @@ impl Database { { let body = async { loop { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) - .await?; - } - - let mut tx = Arc::new(Some(tx)); - let result = f(TransactionHandle(tx.clone())).await; - let tx = Arc::get_mut(&mut tx).unwrap().take().unwrap(); - + let (tx, result) = 
self.with_transaction(&f).await?; match result { Ok((room_id, data)) => { let lock = self.rooms.entry(room_id).or_default().clone(); @@ -2257,6 +2229,31 @@ impl Database { body.await } } + + async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> + where + F: Send + Fn(TransactionHandle) -> Fut, + Fut: Send + Future>, + { + let tx = self.pool.begin().await?; + + // In Postgres, serializable transactions are opt-in + if let DatabaseBackend::Postgres = self.pool.get_database_backend() { + tx.execute(Statement::from_string( + DatabaseBackend::Postgres, + "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), + )) + .await?; + } + + let mut tx = Arc::new(Some(tx)); + let result = f(TransactionHandle(tx.clone())).await; + let Some(tx) = Arc::get_mut(&mut tx).and_then(|tx| tx.take()) else { + return Err(anyhow!("couldn't complete transaction because it's still in use"))?; + }; + + Ok((tx, result)) + } } struct TransactionHandle(Arc>); From d97a8364adc2340ff4388ad21333ef52961e4426 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 10:49:53 +0100 Subject: [PATCH 101/109] Retry transactions if there's a serialization failure during commit --- crates/collab/src/db.rs | 163 +++++++++++++++++++++------------------- 1 file changed, 87 insertions(+), 76 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 3066260bc431f65f18a68bbc7bd68442c18e0078..bc074e30df5ac6bc4d80fe62e42ee6cd78ed6387 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2131,47 +2131,30 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - let body = async { - loop { - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok(result) => { - tx.commit().await?; - return Ok(result); - } - Err(error) => { - tx.rollback().await?; - match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) 
if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { + loop { + let (tx, result) = self.run(self.with_transaction(&f)).await?; + match result { + Ok(result) => { + match self.run(async move { Ok(tx.commit().await?) }).await { + Ok(()) => return Ok(result), + Err(error) => { + if is_serialization_error(&error) { // Retry (don't break the loop) + } else { + return Err(error); } - error @ _ => return Err(error), } } } + Err(error) => { + self.run(tx.rollback()).await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } + } } - }; - - #[cfg(test)] - { - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - self.runtime.as_ref().unwrap().block_on(body) - } - - #[cfg(not(test))] - { - body.await } } @@ -2180,53 +2163,38 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - let body = async { - loop { - let (tx, result) = self.with_transaction(&f).await?; - match result { - Ok((room_id, data)) => { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - tx.commit().await?; - return Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }); - } - Err(error) => { - tx.rollback().await?; - match error { - Error::Database( - DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) - | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), - ) if error - .as_database_error() - .and_then(|error| error.code()) - .as_deref() - == Some("40001") => - { + loop { + let (tx, result) = self.run(self.with_transaction(&f)).await?; + match result { + Ok((room_id, data)) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + match self.run(async move { Ok(tx.commit().await?) 
}).await { + Ok(()) => { + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + if is_serialization_error(&error) { // Retry (don't break the loop) + } else { + return Err(error); } - error @ _ => return Err(error), } } } + Err(error) => { + self.run(tx.rollback()).await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } + } } - }; - - #[cfg(test)] - { - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - self.runtime.as_ref().unwrap().block_on(body) - } - - #[cfg(not(test))] - { - body.await } } @@ -2254,6 +2222,49 @@ impl Database { Ok((tx, result)) } + + async fn run(&self, future: F) -> T + where + F: Future, + { + #[cfg(test)] + { + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + let result = self.runtime.as_ref().unwrap().block_on(future); + + if let Some(background) = self.background.as_ref() { + background.simulate_random_delay().await; + } + + result + } + + #[cfg(not(test))] + { + future.await + } + } +} + +fn is_serialization_error(error: &Error) -> bool { + const SERIALIZATION_FAILURE_CODE: &'static str = "40001"; + match error { + Error::Database( + DbErr::Exec(sea_orm::RuntimeErr::SqlxError(error)) + | DbErr::Query(sea_orm::RuntimeErr::SqlxError(error)), + ) if error + .as_database_error() + .and_then(|error| error.code()) + .as_deref() + == Some(SERIALIZATION_FAILURE_CODE) => + { + true + } + _ => false, + } } struct TransactionHandle(Arc>); From d3c411677ababde3c562c005def58978eb6a944c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 12:03:45 +0100 Subject: [PATCH 102/109] Remove random pauses to prevent the database from deadlocking --- crates/collab/src/db.rs | 108 +++++++++++++------------ crates/collab/src/integration_tests.rs | 8 +- 2 files changed, 62 insertions(+), 54 deletions(-) diff --git 
a/crates/collab/src/db.rs b/crates/collab/src/db.rs index bc074e30df5ac6bc4d80fe62e42ee6cd78ed6387..dfd1d7e65a1d1467aac38d8694d72fd981a6c1da 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -2131,31 +2131,35 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - loop { - let (tx, result) = self.run(self.with_transaction(&f)).await?; - match result { - Ok(result) => { - match self.run(async move { Ok(tx.commit().await?) }).await { - Ok(()) => return Ok(result), - Err(error) => { - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + let body = async { + loop { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok(result) => { + match tx.commit().await.map_err(Into::into) { + Ok(()) => return Ok(result), + Err(error) => { + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } - Err(error) => { - self.run(tx.rollback()).await?; - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + Err(error) => { + tx.rollback().await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } + }; + + self.run(body).await } async fn room_transaction(&self, f: F) -> Result> @@ -2163,39 +2167,43 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - loop { - let (tx, result) = self.run(self.with_transaction(&f)).await?; - match result { - Ok((room_id, data)) => { - let lock = self.rooms.entry(room_id).or_default().clone(); - let _guard = lock.lock_owned().await; - match self.run(async move { Ok(tx.commit().await?) 
}).await { - Ok(()) => { - return Ok(RoomGuard { - data, - _guard, - _not_send: PhantomData, - }); - } - Err(error) => { - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + let body = async { + loop { + let (tx, result) = self.with_transaction(&f).await?; + match result { + Ok((room_id, data)) => { + let lock = self.rooms.entry(room_id).or_default().clone(); + let _guard = lock.lock_owned().await; + match tx.commit().await.map_err(Into::into) { + Ok(()) => { + return Ok(RoomGuard { + data, + _guard, + _not_send: PhantomData, + }); + } + Err(error) => { + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } - Err(error) => { - self.run(tx.rollback()).await?; - if is_serialization_error(&error) { - // Retry (don't break the loop) - } else { - return Err(error); + Err(error) => { + tx.rollback().await?; + if is_serialization_error(&error) { + // Retry (don't break the loop) + } else { + return Err(error); + } } } } - } + }; + + self.run(body).await } async fn with_transaction(&self, f: &F) -> Result<(DatabaseTransaction, Result)> @@ -2233,13 +2241,7 @@ impl Database { background.simulate_random_delay().await; } - let result = self.runtime.as_ref().unwrap().block_on(future); - - if let Some(background) = self.background.as_ref() { - background.simulate_random_delay().await; - } - - result + self.runtime.as_ref().unwrap().block_on(future) } #[cfg(not(test))] diff --git a/crates/collab/src/integration_tests.rs b/crates/collab/src/integration_tests.rs index 73f450b8336c092757e3ca872ce763ab0a405558..4ff372efbe95d4a80d646ddabd87bc9f6267378b 100644 --- a/crates/collab/src/integration_tests.rs +++ b/crates/collab/src/integration_tests.rs @@ -5672,7 +5672,13 @@ impl TestServer { async fn start(background: Arc) -> Self { static NEXT_LIVE_KIT_SERVER_ID: AtomicUsize = AtomicUsize::new(0); - let test_db = TestDb::sqlite(background.clone()); + let use_postgres = 
env::var("USE_POSTGRES").ok(); + let use_postgres = use_postgres.as_deref(); + let test_db = if use_postgres == Some("true") || use_postgres == Some("1") { + TestDb::postgres(background.clone()) + } else { + TestDb::sqlite(background.clone()) + }; let live_kit_server_id = NEXT_LIVE_KIT_SERVER_ID.fetch_add(1, SeqCst); let live_kit_server = live_kit_client::TestServer::create( format!("http://livekit.{}.test", live_kit_server_id), From eec3df09be3825e730b9357b061a9a525f385cb6 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 14:56:01 +0100 Subject: [PATCH 103/109] Upgrade sea-orm --- Cargo.lock | 12 ++++++------ crates/collab/src/db.rs | 20 +++++++------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d1b8a488f2abe489921e72e81a76e207780d6b80..a75ca972e210aaa85daefffb93bf290c215761c4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3204,9 +3204,9 @@ dependencies = [ [[package]] name = "libsqlite3-sys" -version = "0.25.1" +version = "0.25.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f0455f2c1bc9a7caa792907026e469c1d91761fb0ea37cbb16427c77280cf35" +checksum = "29f835d03d717946d28b1d1ed632eb6f0e24a299388ee623d0c23118d3e8a7fa" dependencies = [ "cc", "pkg-config", @@ -5328,9 +5328,9 @@ dependencies = [ [[package]] name = "sea-orm" -version = "0.10.4" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3120bc435b8640963ffda698f877610e07e077157e216eb99408d819c344034d" +checksum = "28fc9dad132e450d6320bd5953e70fb88b42785080b591e9be804da69bd8a170" dependencies = [ "async-stream", "async-trait", @@ -5356,9 +5356,9 @@ dependencies = [ [[package]] name = "sea-orm-macros" -version = "0.10.4" +version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c54bacfeb842813c16821e21f9456c358861a448294075184ea1d6307e386d08" +checksum = "66af5d33e04e56dafb2c700f9b1201a39e6c2c77b53ed9ee93244f21f8de6041" 
dependencies = [ "bae", "heck 0.3.3", diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index dfd1d7e65a1d1467aac38d8694d72fd981a6c1da..8250a8354fa9eae5fd2aa195e10334b84b15f511 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -23,9 +23,9 @@ use hyper::StatusCode; use rpc::{proto, ConnectionId}; pub use sea_orm::ConnectOptions; use sea_orm::{ - entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseBackend, DatabaseConnection, - DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, JoinType, QueryOrder, - QuerySelect, Statement, TransactionTrait, + entity::prelude::*, ActiveValue, ConnectionTrait, DatabaseConnection, DatabaseTransaction, + DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, + Statement, TransactionTrait, }; use sea_query::{Alias, Expr, OnConflict, Query}; use serde::{Deserialize, Serialize}; @@ -2211,16 +2211,10 @@ impl Database { F: Send + Fn(TransactionHandle) -> Fut, Fut: Send + Future>, { - let tx = self.pool.begin().await?; - - // In Postgres, serializable transactions are opt-in - if let DatabaseBackend::Postgres = self.pool.get_database_backend() { - tx.execute(Statement::from_string( - DatabaseBackend::Postgres, - "SET TRANSACTION ISOLATION LEVEL SERIALIZABLE;".into(), - )) + let tx = self + .pool + .begin_with_config(Some(IsolationLevel::Serializable), None) .await?; - } let mut tx = Arc::new(Some(tx)); let result = f(TransactionHandle(tx.clone())).await; @@ -2584,7 +2578,7 @@ mod test { impl Drop for TestDb { fn drop(&mut self) { let db = self.db.take().unwrap(); - if let DatabaseBackend::Postgres = db.pool.get_database_backend() { + if let sea_orm::DatabaseBackend::Postgres = db.pool.get_database_backend() { db.runtime.as_ref().unwrap().block_on(async { use util::ResultExt; let query = " From b97c35a4686f27054e4df92616d79afd86c15e21 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 15:16:06 +0100 Subject: [PATCH 104/109] Remove 
project_id foreign key from `room_participants` --- .../collab/migrations.sqlite/20221109000000_test_schema.sql | 4 ++-- .../collab/migrations/20221111092550_reconnection_support.sql | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 347db6a71a8d44f21f5cfcac7c3c73a1c67856c9..90fd8ace122ff0a6e28b879634b574e6876951a0 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -118,8 +118,8 @@ CREATE TABLE "room_participants" ( "answering_connection_id" INTEGER, "answering_connection_epoch" TEXT, "location_kind" INTEGER, - "location_project_id" INTEGER REFERENCES projects (id), - "initial_project_id" INTEGER REFERENCES projects (id), + "location_project_id" INTEGER, + "initial_project_id" INTEGER, "calling_user_id" INTEGER NOT NULL REFERENCES users (id), "calling_connection_id" INTEGER NOT NULL, "calling_connection_epoch" TEXT NOT NULL diff --git a/crates/collab/migrations/20221111092550_reconnection_support.sql b/crates/collab/migrations/20221111092550_reconnection_support.sql index 6278fa7a595b05cb7adbf97f622b06f675116af3..5e8bada2f9492b91212108e0eae1b0b99d53b63a 100644 --- a/crates/collab/migrations/20221111092550_reconnection_support.sql +++ b/crates/collab/migrations/20221111092550_reconnection_support.sql @@ -80,8 +80,8 @@ CREATE TABLE "room_participants" ( "answering_connection_id" INTEGER, "answering_connection_epoch" UUID, "location_kind" INTEGER, - "location_project_id" INTEGER REFERENCES projects (id), - "initial_project_id" INTEGER REFERENCES projects (id), + "location_project_id" INTEGER, + "initial_project_id" INTEGER, "calling_user_id" INTEGER NOT NULL REFERENCES users (id), "calling_connection_id" INTEGER NOT NULL, "calling_connection_epoch" UUID NOT NULL From be3fb1e9856e11416963716f367ddfda1ca44163 Mon Sep 17 
00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 17:57:10 +0100 Subject: [PATCH 105/109] Update sea-orm to fix bug on failure to commit transactions Co-Authored-By: Nathan Sobo --- Cargo.lock | 6 ++---- crates/collab/Cargo.toml | 5 +++-- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a75ca972e210aaa85daefffb93bf290c215761c4..30c5054576a4ef8caf0f938fb685cfb6dda4860b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5329,8 +5329,7 @@ dependencies = [ [[package]] name = "sea-orm" version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "28fc9dad132e450d6320bd5953e70fb88b42785080b591e9be804da69bd8a170" +source = "git+https://github.com/zed-industries/sea-orm?rev=18f4c691085712ad014a51792af75a9044bacee6#18f4c691085712ad014a51792af75a9044bacee6" dependencies = [ "async-stream", "async-trait", @@ -5357,8 +5356,7 @@ dependencies = [ [[package]] name = "sea-orm-macros" version = "0.10.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66af5d33e04e56dafb2c700f9b1201a39e6c2c77b53ed9ee93244f21f8de6041" +source = "git+https://github.com/zed-industries/sea-orm?rev=18f4c691085712ad014a51792af75a9044bacee6#18f4c691085712ad014a51792af75a9044bacee6" dependencies = [ "bae", "heck 0.3.3", diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index 2238be2257335d09fab631b9929dc79ef7566c2d..8725642ae52a4244234dad2c364ebc3294673dce 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -36,7 +36,8 @@ prometheus = "0.13" rand = "0.8" reqwest = { version = "0.11", features = ["json"], optional = true } scrypt = "0.7" -sea-orm = { version = "0.10", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] } +# Remove fork dependency when a version with https://github.com/SeaQL/sea-orm/pull/1283 is released. 
+sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-postgres", "postgres-array", "runtime-tokio-rustls"] } sea-query = "0.27" serde = { version = "1.0", features = ["derive", "rc"] } serde_json = "1.0" @@ -74,7 +75,7 @@ env_logger = "0.9" log = { version = "0.4.16", features = ["kv_unstable_serde"] } util = { path = "../util" } lazy_static = "1.4" -sea-orm = { version = "0.10", features = ["sqlx-sqlite"] } +sea-orm = { git = "https://github.com/zed-industries/sea-orm", rev = "18f4c691085712ad014a51792af75a9044bacee6", features = ["sqlx-sqlite"] } serde_json = { version = "1.0", features = ["preserve_order"] } sqlx = { version = "0.6", features = ["sqlite"] } unindent = "0.1" From 5443d9cffe17a8faa1299d4852cd4d4c2ff4aa8c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 18:37:01 +0100 Subject: [PATCH 106/109] Return project collaborators and connection IDs in a `RoomGuard` --- crates/collab/src/db.rs | 20 +++++++--- crates/collab/src/rpc.rs | 81 +++++++++++++++++++++------------------- 2 files changed, 57 insertions(+), 44 deletions(-) diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 8250a8354fa9eae5fd2aa195e10334b84b15f511..915acb00eb504f792c3dcd1bef873c6db546dae1 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1981,8 +1981,12 @@ impl Database { &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result> { - self.transaction(|tx| async move { + ) -> Result>> { + self.room_transaction(|tx| async move { + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? 
+ .ok_or_else(|| anyhow!("no such project"))?; let collaborators = project_collaborator::Entity::find() .filter(project_collaborator::Column::ProjectId.eq(project_id)) .all(&*tx) @@ -1992,7 +1996,7 @@ impl Database { .iter() .any(|collaborator| collaborator.connection_id == connection_id.0 as i32) { - Ok(collaborators) + Ok((project.room_id, collaborators)) } else { Err(anyhow!("no such project"))? } @@ -2004,13 +2008,17 @@ impl Database { &self, project_id: ProjectId, connection_id: ConnectionId, - ) -> Result> { - self.transaction(|tx| async move { + ) -> Result>> { + self.room_transaction(|tx| async move { #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] enum QueryAs { ConnectionId, } + let project = project::Entity::find_by_id(project_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such project"))?; let mut db_connection_ids = project_collaborator::Entity::find() .select_only() .column_as( @@ -2028,7 +2036,7 @@ impl Database { } if connection_ids.contains(&connection_id) { - Ok(connection_ids) + Ok((project.room_id, connection_ids)) } else { Err(anyhow!("no such project"))? 
} diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 7f404feffe04d02329b58be660e226cf4d3fe008..79544de6fbdcc959f82c79c7cd830336cc6e2696 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -1245,7 +1245,7 @@ async fn update_language_server( .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1264,23 +1264,24 @@ where T: EntityMessage + RequestMessage, { let project_id = ProjectId::from_proto(request.remote_entity_id()); - let collaborators = session - .db() - .await - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; + let host_connection_id = { + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + ConnectionId( + collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))? 
+ .connection_id as u32, + ) + }; let payload = session .peer - .forward_request( - session.connection_id, - ConnectionId(host.connection_id as u32), - request, - ) + .forward_request(session.connection_id, host_connection_id, request) .await?; response.send(payload)?; @@ -1293,16 +1294,18 @@ async fn save_buffer( session: Session, ) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); - let collaborators = session - .db() - .await - .project_collaborators(project_id, session.connection_id) - .await?; - let host = collaborators - .into_iter() - .find(|collaborator| collaborator.is_host) - .ok_or_else(|| anyhow!("host not found"))?; - let host_connection_id = ConnectionId(host.connection_id as u32); + let host_connection_id = { + let collaborators = session + .db() + .await + .project_collaborators(project_id, session.connection_id) + .await?; + let host = collaborators + .iter() + .find(|collaborator| collaborator.is_host) + .ok_or_else(|| anyhow!("host not found"))?; + ConnectionId(host.connection_id as u32) + }; let response_payload = session .peer .forward_request(session.connection_id, host_connection_id, request.clone()) @@ -1316,7 +1319,7 @@ async fn save_buffer( collaborators .retain(|collaborator| collaborator.connection_id != session.connection_id.0 as i32); let project_connection_ids = collaborators - .into_iter() + .iter() .map(|collaborator| ConnectionId(collaborator.connection_id as u32)); broadcast(host_connection_id, project_connection_ids, |conn_id| { session @@ -1353,7 +1356,7 @@ async fn update_buffer( broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1374,7 +1377,7 @@ async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1393,7 +1396,7 @@ async fn buffer_reloaded(request: 
proto::BufferReloaded, session: Session) -> Re .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1412,7 +1415,7 @@ async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<( .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer @@ -1430,14 +1433,16 @@ async fn follow( let project_id = ProjectId::from_proto(request.project_id); let leader_id = ConnectionId(request.leader_id); let follower_id = session.connection_id; - let project_connection_ids = session - .db() - .await - .project_connection_ids(project_id, session.connection_id) - .await?; + { + let project_connection_ids = session + .db() + .await + .project_connection_ids(project_id, session.connection_id) + .await?; - if !project_connection_ids.contains(&leader_id) { - Err(anyhow!("no such peer"))?; + if !project_connection_ids.contains(&leader_id) { + Err(anyhow!("no such peer"))?; + } } let mut response_payload = session @@ -1691,7 +1696,7 @@ async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> R .await?; broadcast( session.connection_id, - project_connection_ids, + project_connection_ids.iter().copied(), |connection_id| { session .peer From 7bbd97cfb96ca176d345831beb490fc6a7b2c76a Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 19:07:06 +0100 Subject: [PATCH 107/109] Send diagnostic summaries synchronously --- crates/collab/src/rpc.rs | 10 +++---- crates/project/src/worktree.rs | 48 ++++++++++++++++------------------ crates/rpc/src/proto.rs | 1 - 3 files changed, 26 insertions(+), 33 deletions(-) diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 79544de6fbdcc959f82c79c7cd830336cc6e2696..0136a5fec6b1326aace79dc11ea1f6f310c3b705 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -201,7 +201,7 @@ impl Server { 
.add_request_handler(update_worktree) .add_message_handler(start_language_server) .add_message_handler(update_language_server) - .add_request_handler(update_diagnostic_summary) + .add_message_handler(update_diagnostic_summary) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) .add_request_handler(forward_project_request::) @@ -1187,14 +1187,13 @@ async fn update_worktree( } async fn update_diagnostic_summary( - request: proto::UpdateDiagnosticSummary, - response: Response, + message: proto::UpdateDiagnosticSummary, session: Session, ) -> Result<()> { let guest_connection_ids = session .db() .await - .update_diagnostic_summary(&request, session.connection_id) + .update_diagnostic_summary(&message, session.connection_id) .await?; broadcast( @@ -1203,11 +1202,10 @@ async fn update_diagnostic_summary( |connection_id| { session .peer - .forward_send(session.connection_id, connection_id, request.clone()) + .forward_send(session.connection_id, connection_id, message.clone()) }, ); - response.send(proto::Ack {})?; Ok(()) } diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 409f65f78655420c32bf455b9d54b4e695ec62d5..4781e17541a936e7e58cd19fc324a974da918076 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -168,9 +168,7 @@ enum ScanState { struct ShareState { project_id: u64, snapshots_tx: watch::Sender, - diagnostic_summaries_tx: mpsc::UnboundedSender<(Arc, DiagnosticSummary)>, _maintain_remote_snapshot: Task>, - _maintain_remote_diagnostic_summaries: Task<()>, } pub enum Event { @@ -532,9 +530,18 @@ impl LocalWorktree { let updated = !old_summary.is_empty() || !new_summary.is_empty(); if updated { if let Some(share) = self.share.as_ref() { - let _ = share - .diagnostic_summaries_tx - .unbounded_send((worktree_path.clone(), new_summary)); + self.client + .send(proto::UpdateDiagnosticSummary { + project_id: share.project_id, + worktree_id: self.id().to_proto(), + 
summary: Some(proto::DiagnosticSummary { + path: worktree_path.to_string_lossy().to_string(), + language_server_id: language_server_id as u64, + error_count: new_summary.error_count as u32, + warning_count: new_summary.warning_count as u32, + }), + }) + .log_err(); } } @@ -968,6 +975,16 @@ impl LocalWorktree { let (snapshots_tx, mut snapshots_rx) = watch::channel_with(self.snapshot()); let worktree_id = cx.model_id() as u64; + for (path, summary) in self.diagnostic_summaries.iter() { + if let Err(e) = self.client.send(proto::UpdateDiagnosticSummary { + project_id, + worktree_id, + summary: Some(summary.to_proto(&path.0)), + }) { + return Task::ready(Err(e)); + } + } + let maintain_remote_snapshot = cx.background().spawn({ let rpc = self.client.clone(); async move { @@ -1017,31 +1034,10 @@ impl LocalWorktree { .log_err() }); - let (diagnostic_summaries_tx, mut diagnostic_summaries_rx) = mpsc::unbounded(); - for (path, summary) in self.diagnostic_summaries.iter() { - let _ = diagnostic_summaries_tx.unbounded_send((path.0.clone(), summary.clone())); - } - let maintain_remote_diagnostic_summaries = cx.background().spawn({ - let rpc = self.client.clone(); - async move { - while let Some((path, summary)) = diagnostic_summaries_rx.next().await { - rpc.request(proto::UpdateDiagnosticSummary { - project_id, - worktree_id, - summary: Some(summary.to_proto(&path)), - }) - .await - .log_err(); - } - } - }); - self.share = Some(ShareState { project_id, snapshots_tx, - diagnostic_summaries_tx, _maintain_remote_snapshot: maintain_remote_snapshot, - _maintain_remote_diagnostic_summaries: maintain_remote_diagnostic_summaries, }); } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 50f3c57f2a6b3c5bd9bc6798e468df7a541a2f07..6d9bc9a0aa348af8c1a14f442323fcf06064688e 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -228,7 +228,6 @@ request_messages!( (ShareProject, ShareProjectResponse), (Test, Test), (UpdateBuffer, Ack), - 
(UpdateDiagnosticSummary, Ack), (UpdateParticipantLocation, Ack), (UpdateProject, Ack), (UpdateWorktree, Ack), From cd08d289aa8e9790d7ed4b1acf55e59c600ddc01 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 5 Dec 2022 19:45:56 +0100 Subject: [PATCH 108/109] Fix warnings --- crates/workspace/src/workspace.rs | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 45de94b6030b090136137d6a689678579839f957..a0c353b3f808bf1f1a5c9a9909f2047139916449 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -1031,8 +1031,10 @@ impl Workspace { RemoveWorktreeFromProject(worktree_id): &RemoveWorktreeFromProject, cx: &mut ViewContext, ) { - self.project + let future = self + .project .update(cx, |project, cx| project.remove_worktree(*worktree_id, cx)); + cx.foreground().spawn(future).detach(); } fn project_path_for_path( @@ -2862,9 +2864,9 @@ mod tests { ); // Remove a project folder - project.update(cx, |project, cx| { - project.remove_worktree(worktree_id, cx); - }); + project + .update(cx, |project, cx| project.remove_worktree(worktree_id, cx)) + .await; assert_eq!( cx.current_window_title(window_id).as_deref(), Some("one.txt — root2") From 744714b478701e5475bfc28deaf0b94276ae9ff4 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Tue, 6 Dec 2022 09:07:25 +0100 Subject: [PATCH 109/109] Remove unused `UserId` import from seed script --- crates/collab/src/bin/seed.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index 5ddacf6d64d7f487fb8c36e02fae886f528e9873..dfd2ae3a21656fa4b1e8273de748f2765612dc10 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -1,5 +1,5 @@ use collab::db; -use db::{ConnectOptions, Database, UserId}; +use db::{ConnectOptions, Database}; use serde::{de::DeserializeOwned, Deserialize}; use 
std::fmt::Write;