diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 5aae943b58d4b70018f2fb77ad1f0223e2795784..68f1ff7d82033ef4e03aeb0fa4219d664ba70b6d 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -112,12 +112,11 @@ impl Server { addr: String, user_id: UserId, ) -> impl Future { - let this = self.clone(); + let mut this = self.clone(); async move { let (connection_id, handle_io, mut incoming_rx) = this.peer.add_connection(connection).await; - this.store - .write() + this.state_mut() .await .add_connection(connection_id, user_id); if let Err(err) = this.update_collaborators_for_users(&[user_id]).await { @@ -167,9 +166,9 @@ impl Server { } } - async fn sign_out(self: &Arc, connection_id: ConnectionId) -> tide::Result<()> { + async fn sign_out(self: &mut Arc, connection_id: ConnectionId) -> tide::Result<()> { self.peer.disconnect(connection_id).await; - let removed_connection = self.store.write().await.remove_connection(connection_id)?; + let removed_connection = self.state_mut().await.remove_connection(connection_id)?; for (worktree_id, worktree) in removed_connection.hosted_worktrees { if let Some(share) = worktree.share { @@ -210,13 +209,12 @@ impl Server { } async fn open_worktree( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let receipt = request.receipt(); let host_user_id = self - .store - .read() + .state() .await .user_id_for_connection(request.sender_id)?; @@ -238,7 +236,7 @@ impl Server { } let collaborator_user_ids = collaborator_user_ids.into_iter().collect::>(); - let worktree_id = self.store.write().await.add_worktree(Worktree { + let worktree_id = self.state_mut().await.add_worktree(Worktree { host_connection_id: request.sender_id, collaborator_user_ids: collaborator_user_ids.clone(), root_name: request.payload.root_name, @@ -255,13 +253,12 @@ impl Server { } async fn close_worktree( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; let worktree = self - .store - .write() + .state_mut() .await .remove_worktree(worktree_id, request.sender_id)?; @@ -282,7 +279,7 @@ impl Server { } async fn share_worktree( - self: Arc, + mut self: Arc, mut request: TypedEnvelope, ) -> tide::Result<()> { let worktree = request @@ -296,8 +293,7 @@ impl Server { .collect(); let collaborator_user_ids = - self.store - .write() + self.state_mut() .await .share_worktree(worktree.id, request.sender_id, entries); if let Some(collaborator_user_ids) = collaborator_user_ids { @@ -320,13 +316,12 @@ impl Server { } async fn unshare_worktree( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; let worktree = self - .store - .write() + .state_mut() .await .unshare_worktree(worktree_id, request.sender_id)?; @@ -342,20 +337,16 @@ impl Server { } async fn join_worktree( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let worktree_id = request.payload.worktree_id; let user_id = self - .store - .read() + .state() .await .user_id_for_connection(request.sender_id)?; - let response; - let connection_ids; - let collaborator_user_ids; - let mut state = self.store.write().await; + let mut state = self.state_mut().await; match state.join_worktree(request.sender_id, user_id, worktree_id) { Ok(JoinedWorktree { replica_id, @@ -376,7 +367,7 @@ impl Server { }); } } - response = proto::JoinWorktreeResponse { + let response = proto::JoinWorktreeResponse { worktree: Some(proto::Worktree { id: worktree_id, root_name: worktree.root_name.clone(), @@ -385,10 +376,29 @@ impl Server { replica_id: replica_id as u32, peers, }; - connection_ids = worktree.connection_ids(); - collaborator_user_ids = worktree.collaborator_user_ids.clone(); + let connection_ids = worktree.connection_ids(); + let collaborator_user_ids = worktree.collaborator_user_ids.clone(); + drop(state); + + broadcast(request.sender_id, connection_ids, |conn_id| { + self.peer.send( + conn_id, + proto::AddPeer { + worktree_id, + peer: Some(proto::Peer { + peer_id: request.sender_id.0, + replica_id: response.replica_id, + }), + }, + ) + }) + .await?; + self.peer.respond(request.receipt(), response).await?; + self.update_collaborators_for_users(&collaborator_user_ids) + .await?; } Err(error) => { + drop(state); self.peer .respond_with_error( request.receipt(), @@ -397,44 +407,23 @@ impl Server { }, ) .await?; - return Ok(()); } } - drop(state); - broadcast(request.sender_id, connection_ids, |conn_id| { - self.peer.send( - conn_id, - proto::AddPeer { - worktree_id, - peer: Some(proto::Peer { - peer_id: request.sender_id.0, - replica_id: response.replica_id, - }), - }, - ) - }) - .await?; - self.peer.respond(request.receipt(), response).await?; - self.update_collaborators_for_users(&collaborator_user_ids) - .await?; - Ok(()) } async fn leave_worktree( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let sender_id = request.sender_id; let worktree_id = request.payload.worktree_id; - - if let Some(worktree) = self - .store - .write() + let worktree = self + .state_mut() .await - .leave_worktree(sender_id, worktree_id) - { + .leave_worktree(sender_id, worktree_id); + if let Some(worktree) = worktree { broadcast(sender_id, worktree.connection_ids, |conn_id| { self.peer.send( conn_id, @@ -452,10 +441,10 @@ impl Server { } async fn update_worktree( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { - let connection_ids = self.store.write().await.update_worktree( + let connection_ids = self.state_mut().await.update_worktree( request.sender_id, request.payload.worktree_id, &request.payload.removed_entries, @@ -477,8 +466,7 @@ impl Server { ) -> tide::Result<()> { let receipt = request.receipt(); let host_connection_id = self - .store - .read() + .state() .await .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; let response = self @@ -494,8 +482,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let host_connection_id = self - .store - .read() + .state() .await .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; self.peer @@ -511,7 +498,7 @@ impl Server { let host; let guests; { - let state = self.store.read().await; + let state = self.state().await; host = state .worktree_host_connection_id(request.sender_id, request.payload.worktree_id)?; guests = state @@ -547,8 +534,7 @@ impl Server { ) -> tide::Result<()> { broadcast( request.sender_id, - self.store - .read() + self.state() .await .worktree_connection_ids(request.sender_id, request.payload.worktree_id)?, |connection_id| { @@ -585,8 +571,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .store - .read() + .state() .await .user_id_for_connection(request.sender_id)?; let channels = self.app_state.db.get_accessible_channels(user_id).await?; @@ -637,7 +622,7 @@ impl Server { ) -> tide::Result<()> { let mut send_futures = Vec::new(); - let state = self.store.read().await; + let state = self.state().await; for user_id in user_ids { let collaborators = state.collaborators_for_user(*user_id); for connection_id in state.connection_ids_for_user(*user_id) { @@ -657,12 +642,11 @@ impl Server { } async fn join_channel( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .store - .read() + .state() .await .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); @@ -675,8 +659,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.store - .write() + self.state_mut() .await .join_channel(request.sender_id, channel_id); let messages = self @@ -706,12 +689,11 @@ impl Server { } async fn leave_channel( - self: Arc, + mut self: Arc, request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .store - .read() + .state() .await .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); @@ -724,8 +706,7 @@ impl Server { Err(anyhow!("access denied"))?; } - self.store - .write() + self.state_mut() .await .leave_channel(request.sender_id, channel_id); @@ -741,7 +722,7 @@ impl Server { let user_id; let connection_ids; { - let state = self.store.read().await; + let state = self.state().await; user_id = state.user_id_for_connection(request.sender_id)?; if let Some(ids) = state.channel_connection_ids(channel_id) { connection_ids = ids; @@ -829,8 +810,7 @@ impl Server { request: TypedEnvelope, ) -> tide::Result<()> { let user_id = self - .store - .read() + .state() .await .user_id_for_connection(request.sender_id)?; let channel_id = ChannelId::from_proto(request.payload.channel_id); @@ -872,6 +852,18 @@ impl Server { .await?; Ok(()) } + + fn state<'a>( + self: &'a Arc, + ) -> impl Future> { + self.store.read() + } + + fn state_mut<'a>( + self: &'a mut Arc, + ) -> impl Future> { + self.store.write() + } } pub async fn broadcast(