diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 961bbffae2a90e7a0d49729c0193a744901624e1..009fc9bd1ed8cee0bbec78195ce7390e3c83de50 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,3 +1,5 @@ +mod store; + use super::{ auth, db::{ChannelId, MessageId, UserId}, @@ -8,16 +10,17 @@ use anyhow::anyhow; use async_std::{sync::RwLock, task}; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use futures::{future::BoxFuture, FutureExt}; -use postage::{mpsc, prelude::Sink as _, prelude::Stream as _}; +use postage::{broadcast, mpsc, prelude::Sink as _, prelude::Stream as _}; use sha1::{Digest as _, Sha1}; use std::{ any::TypeId, - collections::{hash_map, HashMap, HashSet}, + collections::{HashMap, HashSet}, future::Future, mem, sync::Arc, time::Instant, }; +use store::{ReplicaId, Store, Worktree}; use surf::StatusCode; use tide::log; use tide::{ @@ -30,8 +33,6 @@ use zrpc::{ Connection, ConnectionId, Peer, TypedEnvelope, }; -type ReplicaId = u16; - type MessageHandler = Box< dyn Send + Sync @@ -40,46 +41,12 @@ type MessageHandler = Box< pub struct Server { peer: Arc, - state: RwLock, + store: RwLock, app_state: Arc, handlers: HashMap, notifications: Option>, } -#[derive(Default)] -struct ServerState { - connections: HashMap, - connections_by_user_id: HashMap>, - worktrees: HashMap, - visible_worktrees_by_user_id: HashMap>, - channels: HashMap, - next_worktree_id: u64, -} - -struct ConnectionState { - user_id: UserId, - worktrees: HashSet, - channels: HashSet, -} - -struct Worktree { - host_connection_id: ConnectionId, - collaborator_user_ids: Vec, - root_name: String, - share: Option, -} - -struct WorktreeShare { - guest_connection_ids: HashMap, - active_replica_ids: HashSet, - entries: HashMap, -} - -#[derive(Default)] -struct Channel { - connection_ids: HashSet, -} - const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; @@ -92,7 +59,7 @@ impl Server { let mut server = Self { peer, app_state, - state: Default::default(), + store: Default::default(), handlers: Default::default(), notifications, }; @@ -100,7 +67,7 @@ impl Server { server .add_handler(Server::ping) .add_handler(Server::open_worktree) - .add_handler(Server::handle_close_worktree) + .add_handler(Server::close_worktree) .add_handler(Server::share_worktree) .add_handler(Server::unshare_worktree) .add_handler(Server::join_worktree) @@ -149,7 +116,10 @@ impl Server { async move { let (connection_id, handle_io, mut incoming_rx) = this.peer.add_connection(connection).await; - this.add_connection(connection_id, user_id).await; + this.store + .write() + .await + .add_connection(connection_id, user_id); if let Err(err) = this.update_collaborators_for_users(&[user_id]).await { log::error!("error updating collaborators for {:?}: {}", user_id, err); } @@ -197,61 +167,40 @@ impl Server { } } - async fn sign_out(self: &Arc, connection_id: zrpc::ConnectionId) -> tide::Result<()> { + async fn sign_out(self: &Arc, connection_id: ConnectionId) -> tide::Result<()> { self.peer.disconnect(connection_id).await; - self.remove_connection(connection_id).await?; - Ok(()) - } + let removed_connection = self.store.write().await.remove_connection(connection_id)?; - // Add a new connection associated with a given user. - async fn add_connection(&self, connection_id: ConnectionId, user_id: UserId) { - let mut state = self.state.write().await; - state.connections.insert( - connection_id, - ConnectionState { - user_id, - worktrees: Default::default(), - channels: Default::default(), - }, - ); - state - .connections_by_user_id - .entry(user_id) - .or_default() - .insert(connection_id); - } - - // Remove the given connection and its association with any worktrees. - async fn remove_connection( - self: &Arc, - connection_id: ConnectionId, - ) -> tide::Result<()> { - let mut worktree_ids = Vec::new(); - let mut state = self.state.write().await; - if let Some(connection) = state.connections.remove(&connection_id) { - worktree_ids = connection.worktrees.into_iter().collect(); - - for channel_id in connection.channels { - if let Some(channel) = state.channels.get_mut(&channel_id) { - channel.connection_ids.remove(&connection_id); - } - } - - let user_connections = state - .connections_by_user_id - .get_mut(&connection.user_id) - .unwrap(); - user_connections.remove(&connection_id); - if user_connections.is_empty() { - state.connections_by_user_id.remove(&connection.user_id); + for (worktree_id, worktree) in removed_connection.hosted_worktrees { + if let Some(share) = worktree.share { + broadcast( + connection_id, + share.guest_connection_ids.keys().copied().collect(), + |conn_id| { + self.peer + .send(conn_id, proto::UnshareWorktree { worktree_id }) + }, + ) + .await?; } } - drop(state); - for worktree_id in worktree_ids { - self.close_worktree(worktree_id, connection_id).await?; + for (worktree_id, peer_ids) in removed_connection.guest_worktree_ids { + broadcast(connection_id, peer_ids, |conn_id| { + self.peer.send( + conn_id, + proto::RemovePeer { + worktree_id, + peer_id: connection_id.0, + }, + ) + }) + .await?; } + self.update_collaborators_for_users(removed_connection.collaborator_ids.iter()) + .await; + Ok(()) } @@ -266,7 +215,7 @@ impl Server { ) -> tide::Result<()> { let receipt = request.receipt(); let host_user_id = self - .state + .store .read() .await .user_id_for_connection(request.sender_id)?; @@ -289,7 +238,7 @@ impl Server { } let collaborator_user_ids = collaborator_user_ids.into_iter().collect::>(); - let worktree_id = self.state.write().await.add_worktree(Worktree { + let worktree_id = self.store.write().await.add_worktree(Worktree { host_connection_id: request.sender_id, collaborator_user_ids: collaborator_user_ids.clone(), root_name: request.payload.root_name, @@ -305,6 +254,33 @@ impl Server { Ok(()) } + async fn close_worktree( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let worktree_id = request.payload.worktree_id; + let worktree = self + .store + .write() + .await + .remove_worktree(worktree_id, request.sender_id)?; + + if let Some(share) = worktree.share { + broadcast( + request.sender_id, + share.guest_connection_ids.keys().copied().collect(), + |conn_id| { + self.peer + .send(conn_id, proto::UnshareWorktree { worktree_id }) + }, + ) + .await?; + } + self.update_collaborators_for_users(&worktree.collaborator_user_ids) + .await?; + Ok(()) + } + async fn share_worktree( self: Arc, mut request: TypedEnvelope, @@ -319,16 +295,12 @@ impl Server { .map(|entry| (entry.id, entry)) .collect(); - let mut state = self.state.write().await; - if let Some(worktree) = state.worktrees.get_mut(&worktree.id) { - worktree.share = Some(WorktreeShare { - guest_connection_ids: Default::default(), - active_replica_ids: Default::default(), - entries, - }); - let collaborator_user_ids = worktree.collaborator_user_ids.clone(); - - drop(state); + if let Some(collaborator_user_ids) = + self.store + .write() + .await + .share_worktree(worktree.id, request.sender_id, entries) + { self.peer .respond(request.receipt(), proto::ShareWorktreeResponse {}) .await?; @@ -352,26 +324,11 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; - - let connection_ids; - let collaborator_user_ids; - { - let mut state = self.state.write().await; - let worktree = state.write_worktree(worktree_id, request.sender_id)?; - if worktree.host_connection_id != request.sender_id { - return Err(anyhow!("no such worktree"))?; - } - - connection_ids = worktree.connection_ids(); - collaborator_user_ids = worktree.collaborator_user_ids.clone(); - - worktree.share.take(); - for connection_id in &connection_ids { - if let Some(connection) = state.connections.get_mut(connection_id) { - connection.worktrees.remove(&worktree_id); - } - } - } + let (connection_ids, collaborator_user_ids) = self + .store + .write() + .await + .unshare_worktree(worktree_id, request.sender_id)?; broadcast(request.sender_id, connection_ids, |conn_id| { self.peer @@ -390,7 +347,7 @@ impl Server { ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; let user_id = self - .state + .store .read() .await .user_id_for_connection(request.sender_id)?; @@ -398,7 +355,7 @@ impl Server { let response; let connection_ids; let collaborator_user_ids; - let mut state = self.state.write().await; + let mut state = self.store.write().await; match state.join_worktree(request.sender_id, user_id, worktree_id) { Ok((peer_replica_id, worktree)) => { let share = worktree.share()?; @@ -462,48 +419,17 @@ impl Server { Ok(()) } - async fn handle_close_worktree( - self: Arc, - request: TypedEnvelope, - ) -> tide::Result<()> { - self.close_worktree(request.payload.worktree_id, request.sender_id) - .await - } - - async fn close_worktree( + async fn leave_worktree( self: &Arc, worktree_id: u64, sender_conn_id: ConnectionId, ) -> tide::Result<()> { - let connection_ids; - let collaborator_user_ids; - let mut is_host = false; - let mut is_guest = false; + if let Some((connection_ids, collaborator_ids)) = self + .store + .write() + .await + .leave_worktree(sender_conn_id, worktree_id) { - let mut state = self.state.write().await; - let worktree = state.write_worktree(worktree_id, sender_conn_id)?; - connection_ids = worktree.connection_ids(); - collaborator_user_ids = worktree.collaborator_user_ids.clone(); - - if worktree.host_connection_id == sender_conn_id { - is_host = true; - state.remove_worktree(worktree_id); - } else { - let share = worktree.share_mut()?; - if let Some(replica_id) = share.guest_connection_ids.remove(&sender_conn_id) { - is_guest = true; - share.active_replica_ids.remove(&replica_id); - } - } - } - - if is_host { - broadcast(sender_conn_id, connection_ids, |conn_id| { - self.peer - .send(conn_id, proto::UnshareWorktree { worktree_id }) - }) - .await?; - } else if is_guest { broadcast(sender_conn_id, connection_ids, |conn_id| { self.peer.send( conn_id, @@ -513,10 +439,10 @@ impl Server { }, ) }) - .await? - } - self.update_collaborators_for_users(&collaborator_user_ids) .await?; + self.update_collaborators_for_users(&collaborator_ids) + .await?; + } Ok(()) } @@ -524,22 +450,19 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - { - let mut state = self.state.write().await; - let worktree = state.write_worktree(request.payload.worktree_id, request.sender_id)?; - let share = worktree.share_mut()?; - - for entry_id in &request.payload.removed_entries { - share.entries.remove(&entry_id); - } - - for entry in &request.payload.updated_entries { - share.entries.insert(entry.id, entry.clone()); - } - } + let connection_ids = self.store.write().await.update_worktree( + request.sender_id, + request.payload.worktree_id, + &request.payload.removed_entries, + &request.payload.updated_entries, + )?; + + broadcast(request.sender_id, connection_ids, |connection_id| { + self.peer + .forward_send(request.sender_id, connection_id, request.payload.clone()) + }) + .await?; - self.broadcast_in_worktree(request.payload.worktree_id, &request) - .await?; Ok(()) } @@ -548,14 +471,11 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let receipt = request.receipt(); - let worktree_id = request.payload.worktree_id; let host_connection_id = self - .state + .store .read() .await - .read_worktree(worktree_id, request.sender_id)? - .host_connection_id; - + .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; let response = self .peer .forward_request(request.sender_id, host_connection_id, request.payload) @@ -569,16 +489,13 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let host_connection_id = self - .state + .store .read() .await - .read_worktree(request.payload.worktree_id, request.sender_id)? - .host_connection_id; - + .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; self.peer .forward_send(request.sender_id, host_connection_id, request.payload) .await?; - Ok(()) } @@ -589,15 +506,11 @@ impl Server { let host; let guests; { - let state = self.state.read().await; - let worktree = state.read_worktree(request.payload.worktree_id, request.sender_id)?; - host = worktree.host_connection_id; - guests = worktree - .share()? - .guest_connection_ids - .keys() - .copied() - .collect::>(); + let state = self.store.read().await; + host = state + .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; + guests = state + .worktree_guest_connection_ids(request.sender_id, request.payload.worktree_id)?; } let sender = request.sender_id; @@ -627,8 +540,18 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - self.broadcast_in_worktree(request.payload.worktree_id, &request) - .await?; + broadcast( + request.sender_id, + self.store + .read() + .await + .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?, + |connection_id| { + self.peer + .forward_send(request.sender_id, connection_id, request.payload.clone()) + }, + ) + .await?; self.peer.respond(request.receipt(), proto::Ack {}).await?; Ok(()) } @@ -637,8 +560,19 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - self.broadcast_in_worktree(request.payload.worktree_id, &request) - .await + broadcast( + request.sender_id, + self.store + .read() + .await + .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?, + |connection_id| { + self.peer + .forward_send(request.sender_id, connection_id, request.payload.clone()) + }, + ) + .await?; + Ok(()) } async fn get_channels( @@ -646,7 +580,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .state + .store .read() .await .user_id_for_connection(request.sender_id)?; @@ -698,45 +632,10 @@ impl Server { ) -> tide::Result<()> { let mut send_futures = Vec::new(); - let state = self.state.read().await; + let state = self.store.read().await; for user_id in user_ids { - let mut collaborators = HashMap::new(); - for worktree_id in state - .visible_worktrees_by_user_id - .get(&user_id) - .unwrap_or(&HashSet::new()) - { - let worktree = &state.worktrees[worktree_id]; - - let mut guests = HashSet::new(); - if let Ok(share) = worktree.share() { - for guest_connection_id in share.guest_connection_ids.keys() { - let user_id = state - .user_id_for_connection(*guest_connection_id) - .context("stale worktree guest connection")?; - guests.insert(user_id.to_proto()); - } - } - - let host_user_id = state - .user_id_for_connection(worktree.host_connection_id) - .context("stale worktree host connection")?; - let host = - collaborators - .entry(host_user_id) - .or_insert_with(|| proto::Collaborator { - user_id: host_user_id.to_proto(), - worktrees: Vec::new(), - }); - host.worktrees.push(proto::WorktreeMetadata { - root_name: worktree.root_name.clone(), - is_shared: worktree.share().is_ok(), - participants: guests.into_iter().collect(), - }); - } - - let collaborators = collaborators.into_values().collect::>(); - for connection_id in state.user_connection_ids(*user_id) { + let collaborators = state.collaborators_for_user(*user_id); + for connection_id in state.connection_ids_for_user(*user_id) { send_futures.push(self.peer.send( connection_id, proto::UpdateCollaborators { @@ -757,7 +656,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .state + .store .read() .await .user_id_for_connection(request.sender_id)?; @@ -771,7 +670,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.state + self.store .write() .await .join_channel(request.sender_id, channel_id); @@ -806,7 +705,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .state + .store .read() .await .user_id_for_connection(request.sender_id)?; @@ -820,7 +719,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.state + self.store .write() .await .leave_channel(request.sender_id, channel_id); @@ -837,10 +736,10 @@ impl Server { let user_id; let connection_ids; { - let state = self.state.read().await; + let state = self.store.read().await; user_id = state.user_id_for_connection(request.sender_id)?; - if let Some(channel) = state.channels.get(&channel_id) { - connection_ids = channel.connection_ids(); + if let Some(ids) = state.channel_connection_ids(channel_id) { + connection_ids = ids; } else { return Ok(()); } @@ -925,7 +824,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .state + .store .read() .await .user_id_for_connection(request.sender_id)?; @@ -968,27 +867,6 @@ impl Server { .await?; Ok(()) } - - async fn broadcast_in_worktree( - &self, - worktree_id: u64, - message: &TypedEnvelope, - ) -> tide::Result<()> { - let connection_ids = self - .state - .read() - .await - .read_worktree(worktree_id, message.sender_id)? - .connection_ids(); - - broadcast(message.sender_id, connection_ids, |conn_id| { - self.peer - .forward_send(message.sender_id, conn_id, message.payload.clone()) - }) - .await?; - - Ok(()) - } } pub async fn broadcast( @@ -1008,292 +886,6 @@ where Ok(()) } -impl ServerState { - fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.channels.insert(channel_id); - self.channels - .entry(channel_id) - .or_default() - .connection_ids - .insert(connection_id); - } - } - - fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { - if let Some(connection) = self.connections.get_mut(&connection_id) { - connection.channels.remove(&channel_id); - if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) { - entry.get_mut().connection_ids.remove(&connection_id); - if entry.get_mut().connection_ids.is_empty() { - entry.remove(); - } - } - } - } - - fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result { - Ok(self - .connections - .get(&connection_id) - .ok_or_else(|| anyhow!("unknown connection"))? - .user_id) - } - - fn user_connection_ids<'a>( - &'a self, - user_id: UserId, - ) -> impl 'a + Iterator { - self.connections_by_user_id - .get(&user_id) - .into_iter() - .flatten() - .copied() - } - - // Add the given connection as a guest of the given worktree - fn join_worktree( - &mut self, - connection_id: ConnectionId, - user_id: UserId, - worktree_id: u64, - ) -> tide::Result<(ReplicaId, &Worktree)> { - let connection = self - .connections - .get_mut(&connection_id) - .ok_or_else(|| anyhow!("no such connection"))?; - let worktree = self - .worktrees - .get_mut(&worktree_id) - .ok_or_else(|| anyhow!("no such worktree"))?; - if !worktree.collaborator_user_ids.contains(&user_id) { - Err(anyhow!("no such worktree"))?; - } - - let share = worktree.share_mut()?; - connection.worktrees.insert(worktree_id); - - let mut replica_id = 1; - while share.active_replica_ids.contains(&replica_id) { - replica_id += 1; - } - share.active_replica_ids.insert(replica_id); - share.guest_connection_ids.insert(connection_id, replica_id); - return Ok((replica_id, worktree)); - } - - fn read_worktree( - &self, - worktree_id: u64, - connection_id: ConnectionId, - ) -> tide::Result<&Worktree> { - let worktree = self - .worktrees - .get(&worktree_id) - .ok_or_else(|| anyhow!("worktree not found"))?; - - if worktree.host_connection_id == connection_id - || worktree - .share()? - .guest_connection_ids - .contains_key(&connection_id) - { - Ok(worktree) - } else { - Err(anyhow!( - "{} is not a member of worktree {}", - connection_id, - worktree_id - ))? - } - } - - fn write_worktree( - &mut self, - worktree_id: u64, - connection_id: ConnectionId, - ) -> tide::Result<&mut Worktree> { - let worktree = self - .worktrees - .get_mut(&worktree_id) - .ok_or_else(|| anyhow!("worktree not found"))?; - - if worktree.host_connection_id == connection_id - || worktree.share.as_ref().map_or(false, |share| { - share.guest_connection_ids.contains_key(&connection_id) - }) - { - Ok(worktree) - } else { - Err(anyhow!( - "{} is not a member of worktree {}", - connection_id, - worktree_id - ))? - } - } - - fn add_worktree(&mut self, worktree: Worktree) -> u64 { - let worktree_id = self.next_worktree_id; - for collaborator_user_id in &worktree.collaborator_user_ids { - self.visible_worktrees_by_user_id - .entry(*collaborator_user_id) - .or_default() - .insert(worktree_id); - } - self.next_worktree_id += 1; - if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) { - connection.worktrees.insert(worktree_id); - } - self.worktrees.insert(worktree_id, worktree); - - #[cfg(test)] - self.check_invariants(); - - worktree_id - } - - fn remove_worktree(&mut self, worktree_id: u64) { - let worktree = self.worktrees.remove(&worktree_id).unwrap(); - if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) { - connection.worktrees.remove(&worktree_id); - } - if let Some(share) = worktree.share { - for connection_id in share.guest_connection_ids.keys() { - if let Some(connection) = self.connections.get_mut(connection_id) { - connection.worktrees.remove(&worktree_id); - } - } - } - for collaborator_user_id in worktree.collaborator_user_ids { - if let Some(visible_worktrees) = self - .visible_worktrees_by_user_id - .get_mut(&collaborator_user_id) - { - visible_worktrees.remove(&worktree_id); - } - } - - #[cfg(test)] - self.check_invariants(); - } - - #[cfg(test)] - fn check_invariants(&self) { - for (connection_id, connection) in &self.connections { - for worktree_id in &connection.worktrees { - let worktree = &self.worktrees.get(&worktree_id).unwrap(); - if worktree.host_connection_id != *connection_id { - assert!(worktree - .share() - .unwrap() - .guest_connection_ids - .contains_key(connection_id)); - } - } - for channel_id in &connection.channels { - let channel = self.channels.get(channel_id).unwrap(); - assert!(channel.connection_ids.contains(connection_id)); - } - assert!(self - .connections_by_user_id - .get(&connection.user_id) - .unwrap() - .contains(connection_id)); - } - - for (user_id, connection_ids) in &self.connections_by_user_id { - for connection_id in connection_ids { - assert_eq!( - self.connections.get(connection_id).unwrap().user_id, - *user_id - ); - } - } - - for (worktree_id, worktree) in &self.worktrees { - let host_connection = self.connections.get(&worktree.host_connection_id).unwrap(); - assert!(host_connection.worktrees.contains(worktree_id)); - - for collaborator_id in &worktree.collaborator_user_ids { - let visible_worktree_ids = self - .visible_worktrees_by_user_id - .get(collaborator_id) - .unwrap(); - assert!(visible_worktree_ids.contains(worktree_id)); - } - - if let Some(share) = &worktree.share { - for guest_connection_id in share.guest_connection_ids.keys() { - let guest_connection = self.connections.get(guest_connection_id).unwrap(); - assert!(guest_connection.worktrees.contains(worktree_id)); - } - assert_eq!( - share.active_replica_ids.len(), - share.guest_connection_ids.len(), - ); - assert_eq!( - share.active_replica_ids, - share - .guest_connection_ids - .values() - .copied() - .collect::>(), - ); - } - } - - for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id { - for worktree_id in visible_worktree_ids { - let worktree = self.worktrees.get(worktree_id).unwrap(); - assert!(worktree.collaborator_user_ids.contains(user_id)); - } - } - - for (channel_id, channel) in &self.channels { - for connection_id in &channel.connection_ids { - let connection = self.connections.get(connection_id).unwrap(); - assert!(connection.channels.contains(channel_id)); - } - } - } -} - -impl Worktree { - pub fn connection_ids(&self) -> Vec { - if let Some(share) = &self.share { - share - .guest_connection_ids - .keys() - .copied() - .chain(Some(self.host_connection_id)) - .collect() - } else { - vec![self.host_connection_id] - } - } - - fn share(&self) -> tide::Result<&WorktreeShare> { - Ok(self - .share - .as_ref() - .ok_or_else(|| anyhow!("worktree is not shared"))?) - } - - fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> { - Ok(self - .share - .as_mut() - .ok_or_else(|| anyhow!("worktree is not shared"))?) - } -} - -impl Channel { - fn connection_ids(&self) -> Vec { - self.connection_ids.iter().copied().collect() - } -} - pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { let server = Server::new(app.state().clone(), rpc.clone(), None); app.at("/rpc").with(auth::VerifyToken).get(move |request: Request>| { @@ -2477,16 +2069,16 @@ mod tests { }) } - async fn state<'a>(&'a self) -> RwLockReadGuard<'a, ServerState> { - self.server.state.read().await + async fn state<'a>(&'a self) -> RwLockReadGuard<'a, Store> { + self.server.store.read().await } async fn condition(&mut self, mut predicate: F) where - F: FnMut(&ServerState) -> bool, + F: FnMut(&Store) -> bool, { async_std::future::timeout(Duration::from_millis(500), async { - while !(predicate)(&*self.server.state.read().await) { + while !(predicate)(&*self.server.store.read().await) { self.notifications.recv().await; } }) diff --git a/server/src/rpc/store.rs b/server/src/rpc/store.rs new file mode 100644 index 0000000000000000000000000000000000000000..c7a6c2b166d716f0b22769f053f47af2dd2595da --- /dev/null +++ b/server/src/rpc/store.rs @@ -0,0 +1,574 @@ +use crate::db::{ChannelId, MessageId, UserId}; +use crate::errors::TideResultExt; +use anyhow::anyhow; +use std::collections::{hash_map, HashMap, HashSet}; +use zrpc::{proto, ConnectionId}; + +#[derive(Default)] +pub struct Store { + connections: HashMap, + connections_by_user_id: HashMap>, + worktrees: HashMap, + visible_worktrees_by_user_id: HashMap>, + channels: HashMap, + next_worktree_id: u64, +} + +struct ConnectionState { + user_id: UserId, + worktrees: HashSet, + channels: HashSet, +} + +pub struct Worktree { + pub host_connection_id: ConnectionId, + pub collaborator_user_ids: Vec, + pub root_name: String, + pub share: Option, +} + +struct WorktreeShare { + pub guest_connection_ids: HashMap, + pub active_replica_ids: HashSet, + pub entries: HashMap, +} + +#[derive(Default)] +struct Channel { + connection_ids: HashSet, +} + +pub type ReplicaId = u16; + +#[derive(Default)] +pub struct RemovedConnectionState { + pub hosted_worktrees: HashMap, + pub guest_worktree_ids: HashMap>, + pub collaborator_ids: HashSet, +} + +impl Store { + pub fn add_connection(&mut self, connection_id: ConnectionId, user_id: UserId) { + self.connections.insert( + connection_id, + ConnectionState { + user_id, + worktrees: Default::default(), + channels: Default::default(), + }, + ); + self.connections_by_user_id + .entry(user_id) + .or_default() + .insert(connection_id); + } + + pub fn remove_connection( + &mut self, + connection_id: ConnectionId, + ) -> tide::Result { + let connection = if let Some(connection) = self.connections.get(&connection_id) { + connection + } else { + return Err(anyhow!("no such connection"))?; + }; + + for channel_id in connection.channels { + if let Some(channel) = self.channels.get_mut(&channel_id) { + channel.connection_ids.remove(&connection_id); + } + } + + let user_connections = self + .connections_by_user_id + .get_mut(&connection.user_id) + .unwrap(); + user_connections.remove(&connection_id); + if user_connections.is_empty() { + self.connections_by_user_id.remove(&connection.user_id); + } + + let mut result = RemovedConnectionState::default(); + for worktree_id in connection.worktrees { + if let Ok(worktree) = self.remove_worktree(worktree_id, connection_id) { + result.hosted_worktrees.insert(worktree_id, worktree); + result + .collaborator_ids + .extend(worktree.collaborator_user_ids.iter().copied()); + } else { + if let Some(worktree) = self.worktrees.get(&worktree_id) { + result + .guest_worktree_ids + .insert(worktree_id, worktree.connection_ids()); + result + .collaborator_ids + .extend(worktree.collaborator_user_ids.iter().copied()); + } + } + } + + Ok(result) + } + + pub fn join_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { + if let Some(connection) = self.connections.get_mut(&connection_id) { + connection.channels.insert(channel_id); + self.channels + .entry(channel_id) + .or_default() + .connection_ids + .insert(connection_id); + } + } + + pub fn leave_channel(&mut self, connection_id: ConnectionId, channel_id: ChannelId) { + if let Some(connection) = self.connections.get_mut(&connection_id) { + connection.channels.remove(&channel_id); + if let hash_map::Entry::Occupied(mut entry) = self.channels.entry(channel_id) { + entry.get_mut().connection_ids.remove(&connection_id); + if entry.get_mut().connection_ids.is_empty() { + entry.remove(); + } + } + } + } + + pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> tide::Result { + Ok(self + .connections + .get(&connection_id) + .ok_or_else(|| anyhow!("unknown connection"))? + .user_id) + } + + pub fn connection_ids_for_user<'a>( + &'a self, + user_id: UserId, + ) -> impl 'a + Iterator { + self.connections_by_user_id + .get(&user_id) + .into_iter() + .flatten() + .copied() + } + + pub fn collaborators_for_user(&self, user_id: UserId) -> Vec { + let mut collaborators = HashMap::new(); + for worktree_id in self + .visible_worktrees_by_user_id + .get(&user_id) + .unwrap_or(&HashSet::new()) + { + let worktree = &self.worktrees[worktree_id]; + + let mut guests = HashSet::new(); + if let Ok(share) = worktree.share() { + for guest_connection_id in share.guest_connection_ids.keys() { + if let Ok(user_id) = self.user_id_for_connection(*guest_connection_id) { + guests.insert(user_id.to_proto()); + } + } + } + + if let Ok(host_user_id) = self + .user_id_for_connection(worktree.host_connection_id) + .context("stale worktree host connection") + { + let host = + collaborators + .entry(host_user_id) + .or_insert_with(|| proto::Collaborator { + user_id: host_user_id.to_proto(), + worktrees: Vec::new(), + }); + host.worktrees.push(proto::WorktreeMetadata { + root_name: worktree.root_name.clone(), + is_shared: worktree.share().is_ok(), + participants: guests.into_iter().collect(), + }); + } + } + + collaborators.into_values().collect() + } + + pub fn add_worktree(&mut self, worktree: Worktree) -> u64 { + let worktree_id = self.next_worktree_id; + for collaborator_user_id in &worktree.collaborator_user_ids { + self.visible_worktrees_by_user_id + .entry(*collaborator_user_id) + .or_default() + .insert(worktree_id); + } + self.next_worktree_id += 1; + if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) { + connection.worktrees.insert(worktree_id); + } + self.worktrees.insert(worktree_id, worktree); + + #[cfg(test)] + self.check_invariants(); + + worktree_id + } + + pub fn remove_worktree( + &mut self, + worktree_id: u64, + acting_connection_id: ConnectionId, + ) -> tide::Result { + let worktree = if let hash_map::Entry::Occupied(e) = self.worktrees.entry(worktree_id) { + if e.get().host_connection_id != acting_connection_id { + Err(anyhow!("not your worktree"))?; + } + e.remove() + } else { + return Err(anyhow!("no such worktree"))?; + }; + + if let Some(connection) = self.connections.get_mut(&worktree.host_connection_id) { + connection.worktrees.remove(&worktree_id); + } + + if let Some(share) = worktree.share { + for connection_id in share.guest_connection_ids.keys() { + if let Some(connection) = self.connections.get_mut(connection_id) { + connection.worktrees.remove(&worktree_id); + } + } + } + + for collaborator_user_id in worktree.collaborator_user_ids { + if let Some(visible_worktrees) = self + .visible_worktrees_by_user_id + .get_mut(&collaborator_user_id) + { + visible_worktrees.remove(&worktree_id); + } + } + + #[cfg(test)] + self.check_invariants(); + + Ok(worktree) + } + + pub fn share_worktree( + &mut self, + worktree_id: u64, + connection_id: ConnectionId, + entries: HashMap, + ) -> Option> { + if let Some(worktree) = self.worktrees.get_mut(&worktree_id) { + if worktree.host_connection_id == connection_id { + worktree.share = Some(WorktreeShare { + guest_connection_ids: Default::default(), + active_replica_ids: Default::default(), + entries, + }); + return Some(worktree.collaborator_user_ids.clone()); + } + } + None + } + + pub fn unshare_worktree( + &mut self, + worktree_id: u64, + acting_connection_id: ConnectionId, + ) -> tide::Result<(Vec, Vec)> { + let worktree = if let Some(worktree) = self.worktrees.get_mut(&worktree_id) { + worktree + } else { + return Err(anyhow!("no such worktree"))?; + }; + + if worktree.host_connection_id != acting_connection_id { + return Err(anyhow!("not your worktree"))?; + } + + let connection_ids = worktree.connection_ids(); + + if let Some(share) = worktree.share.take() { + for connection_id in &connection_ids { + if let Some(connection) = self.connections.get_mut(connection_id) { + connection.worktrees.remove(&worktree_id); + } + } + Ok((connection_ids, worktree.collaborator_user_ids.clone())) + } else { + Err(anyhow!("worktree is not shared"))? + } + } + + pub fn join_worktree( + &mut self, + connection_id: ConnectionId, + user_id: UserId, + worktree_id: u64, + ) -> tide::Result<(ReplicaId, &Worktree)> { + let connection = self + .connections + .get_mut(&connection_id) + .ok_or_else(|| anyhow!("no such connection"))?; + let worktree = self + .worktrees + .get_mut(&worktree_id) + .and_then(|worktree| { + if worktree.collaborator_user_ids.contains(&user_id) { + Some(worktree) + } else { + None + } + }) + .ok_or_else(|| anyhow!("no such worktree"))?; + + let share = worktree.share_mut()?; + connection.worktrees.insert(worktree_id); + + let mut replica_id = 1; + while share.active_replica_ids.contains(&replica_id) { + replica_id += 1; + } + share.active_replica_ids.insert(replica_id); + share.guest_connection_ids.insert(connection_id, replica_id); + return Ok((replica_id, worktree)); + } + + pub fn leave_worktree( + &mut self, + connection_id: ConnectionId, + worktree_id: u64, + ) -> Option<(Vec, Vec)> { + let worktree = self.worktrees.get_mut(&worktree_id)?; + let share = worktree.share.as_mut()?; + let replica_id = share.guest_connection_ids.remove(&connection_id)?; + share.active_replica_ids.remove(&replica_id); + Some(( + worktree.connection_ids(), + worktree.collaborator_user_ids.clone(), + )) + } + + pub fn update_worktree( + &mut self, + connection_id: ConnectionId, + worktree_id: u64, + removed_entries: &[u64], + updated_entries: &[proto::Entry], + ) -> tide::Result> { + let worktree = self.write_worktree(worktree_id, connection_id)?; + let share = worktree.share_mut()?; + for entry_id in removed_entries { + share.entries.remove(&entry_id); + } + for entry in updated_entries { + share.entries.insert(entry.id, entry.clone()); + } + Ok(worktree.connection_ids()) + } + + pub fn worktree_host_connection_id( + &self, + connection_id: ConnectionId, + worktree_id: u64, + ) -> tide::Result { + Ok(self + .read_worktree(worktree_id, connection_id)? + .host_connection_id) + } + + pub fn worktree_guest_connection_ids( + &self, + connection_id: ConnectionId, + worktree_id: u64, + ) -> tide::Result> { + Ok(self + .read_worktree(worktree_id, connection_id)? + .share()? + .guest_connection_ids + .keys() + .copied() + .collect()) + } + + pub fn worktree_connection_ids( + &self, + connection_id: ConnectionId, + worktree_id: u64, + ) -> tide::Result> { + Ok(self + .read_worktree(worktree_id, connection_id)? + .connection_ids()) + } + + pub fn channel_connection_ids(&self, channel_id: ChannelId) -> Option> { + Some(self.channels.get(&channel_id)?.connection_ids()) + } + + fn read_worktree( + &self, + worktree_id: u64, + connection_id: ConnectionId, + ) -> tide::Result<&Worktree> { + let worktree = self + .worktrees + .get(&worktree_id) + .ok_or_else(|| anyhow!("worktree not found"))?; + + if worktree.host_connection_id == connection_id + || worktree + .share()? + .guest_connection_ids + .contains_key(&connection_id) + { + Ok(worktree) + } else { + Err(anyhow!( + "{} is not a member of worktree {}", + connection_id, + worktree_id + ))? + } + } + + fn write_worktree( + &mut self, + worktree_id: u64, + connection_id: ConnectionId, + ) -> tide::Result<&mut Worktree> { + let worktree = self + .worktrees + .get_mut(&worktree_id) + .ok_or_else(|| anyhow!("worktree not found"))?; + + if worktree.host_connection_id == connection_id + || worktree.share.as_ref().map_or(false, |share| { + share.guest_connection_ids.contains_key(&connection_id) + }) + { + Ok(worktree) + } else { + Err(anyhow!( + "{} is not a member of worktree {}", + connection_id, + worktree_id + ))? + } + } + + #[cfg(test)] + fn check_invariants(&self) { + for (connection_id, connection) in &self.connections { + for worktree_id in &connection.worktrees { + let worktree = &self.worktrees.get(&worktree_id).unwrap(); + if worktree.host_connection_id != *connection_id { + assert!(worktree + .share() + .unwrap() + .guest_connection_ids + .contains_key(connection_id)); + } + } + for channel_id in &connection.channels { + let channel = self.channels.get(channel_id).unwrap(); + assert!(channel.connection_ids.contains(connection_id)); + } + assert!(self + .connections_by_user_id + .get(&connection.user_id) + .unwrap() + .contains(connection_id)); + } + + for (user_id, connection_ids) in &self.connections_by_user_id { + for connection_id in connection_ids { + assert_eq!( + self.connections.get(connection_id).unwrap().user_id, + *user_id + ); + } + } + + for (worktree_id, worktree) in &self.worktrees { + let host_connection = self.connections.get(&worktree.host_connection_id).unwrap(); + assert!(host_connection.worktrees.contains(worktree_id)); + + for collaborator_id in &worktree.collaborator_user_ids { + let visible_worktree_ids = self + .visible_worktrees_by_user_id + .get(collaborator_id) + .unwrap(); + assert!(visible_worktree_ids.contains(worktree_id)); + } + + if let Some(share) = &worktree.share { + for guest_connection_id in share.guest_connection_ids.keys() { + let guest_connection = self.connections.get(guest_connection_id).unwrap(); + assert!(guest_connection.worktrees.contains(worktree_id)); + } + assert_eq!( + share.active_replica_ids.len(), + share.guest_connection_ids.len(), + ); + assert_eq!( + share.active_replica_ids, + share + .guest_connection_ids + .values() + .copied() + .collect::>(), + ); + } + } + + for (user_id, visible_worktree_ids) in &self.visible_worktrees_by_user_id { + for worktree_id in visible_worktree_ids { + let worktree = self.worktrees.get(worktree_id).unwrap(); + assert!(worktree.collaborator_user_ids.contains(user_id)); + } + } + + for (channel_id, channel) in &self.channels { + for connection_id in &channel.connection_ids { + let connection = self.connections.get(connection_id).unwrap(); + assert!(connection.channels.contains(channel_id)); + } + } + } +} + +impl Worktree { + pub fn connection_ids(&self) -> Vec { + if let Some(share) = &self.share { + share + .guest_connection_ids + .keys() + .copied() + .chain(Some(self.host_connection_id)) + .collect() + } else { + vec![self.host_connection_id] + } + } + + pub fn share(&self) -> tide::Result<&WorktreeShare> { + Ok(self + .share + .as_ref() + .ok_or_else(|| anyhow!("worktree is not shared"))?) + } + + fn share_mut(&mut self) -> tide::Result<&mut WorktreeShare> { + Ok(self + .share + .as_mut() + .ok_or_else(|| anyhow!("worktree is not shared"))?) + } +} + +impl Channel { + fn connection_ids(&self) -> Vec { + self.connection_ids.iter().copied().collect() + } +} diff --git a/zed/src/worktree.rs b/zed/src/worktree.rs index 93be5061341216e4e2fbe33d848a3033f53efc6f..7c766bc0d7c8abbf044f86cc2821f70d979fce73 100644 --- a/zed/src/worktree.rs +++ b/zed/src/worktree.rs @@ -66,21 +66,28 @@ impl Entity for Worktree { type Event = (); fn release(&mut self, cx: &mut MutableAppContext) { - let rpc = match self { - Self::Local(tree) => tree - .remote_id - .borrow() - .map(|remote_id| (tree.rpc.clone(), remote_id)), - Self::Remote(tree) => Some((tree.rpc.clone(), tree.remote_id)), - }; - - if let Some((rpc, worktree_id)) = rpc { - cx.spawn(|_| async move { - if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await { - log::error!("error closing worktree {}: {}", worktree_id, err); + match self { + Self::Local(tree) => { + if let Some(worktree_id) = *tree.remote_id.borrow() { + let rpc = tree.rpc.clone(); + cx.spawn(|_| async move { + if let Err(err) = rpc.send(proto::CloseWorktree { worktree_id }).await { + log::error!("error closing worktree: {}", err); + } + }) + .detach(); } - }) - .detach(); + } + Self::Remote(tree) => { + let rpc = tree.rpc.clone(); + let worktree_id = tree.remote_id; + cx.spawn(|_| async move { + if let Err(err) = rpc.send(proto::LeaveWorktree { worktree_id }).await { + log::error!("error closing worktree: {}", err); + } + }) + .detach(); + } } } } diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index 074bbe60938c833dfc457c6877e2a7c8aef0a9f1..340c0751aa3255a5fded2f7ab71a212c7322f9a0 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -39,6 +39,7 @@ message Envelope { OpenWorktreeResponse open_worktree_response = 34; UnshareWorktree unshare_worktree = 35; UpdateCollaborators update_collaborators = 36; + LeaveWorktree leave_worktree = 37; } } @@ -75,6 +76,10 @@ message JoinWorktree { uint64 worktree_id = 1; } +message LeaveWorktree { + uint64 worktree_id = 1; +} + message JoinWorktreeResponse { Worktree worktree = 2; uint32 replica_id = 3; diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index f094923af387accc6d76e1f0187739db68c799df..92fca53e28335680f2cf7227e0eea32a68a54e8b 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -139,6 +139,7 @@ messages!( JoinWorktree, JoinWorktreeResponse, LeaveChannel, + LeaveWorktree, OpenBuffer, OpenBufferResponse, OpenWorktree,