From f4697ff4d14300b9a18619b1a113faa2cac2e8c7 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Mon, 26 Sep 2022 11:13:34 +0200 Subject: [PATCH] Prevent the same user from being called more than once --- crates/collab/src/rpc/store.rs | 83 ++++++++++++++++++++++++++-------- 1 file changed, 64 insertions(+), 19 deletions(-) diff --git a/crates/collab/src/rpc/store.rs b/crates/collab/src/rpc/store.rs index 6b756918021f9771e08e00ded5f033ae5ad9a9ec..d19ae122e095a048c66b58893014f4f76d87bce6 100644 --- a/crates/collab/src/rpc/store.rs +++ b/crates/collab/src/rpc/store.rs @@ -13,7 +13,7 @@ pub type RoomId = u64; #[derive(Default, Serialize)] pub struct Store { connections: BTreeMap, - connections_by_user_id: BTreeMap>, + connections_by_user_id: BTreeMap, next_room_id: RoomId, rooms: BTreeMap, projects: BTreeMap, @@ -21,16 +21,27 @@ pub struct Store { channels: BTreeMap, } +#[derive(Default, Serialize)] +struct UserConnectionState { + connection_ids: HashSet, + room: Option, +} + #[derive(Serialize)] struct ConnectionState { user_id: UserId, admin: bool, - room: Option, projects: BTreeSet, requested_projects: HashSet, channels: HashSet, } +#[derive(Copy, Clone, Eq, PartialEq, Serialize)] +enum RoomState { + Joined, + Calling { room_id: RoomId }, +} + #[derive(Serialize)] pub struct Project { pub online: bool, @@ -140,7 +151,6 @@ impl Store { ConnectionState { user_id, admin, - room: Default::default(), projects: Default::default(), requested_projects: Default::default(), channels: Default::default(), @@ -149,6 +159,7 @@ impl Store { self.connections_by_user_id .entry(user_id) .or_default() + .connection_ids .insert(connection_id); } @@ -185,9 +196,9 @@ impl Store { } } - let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap(); - user_connections.remove(&connection_id); - if user_connections.is_empty() { + let user_connection_state = self.connections_by_user_id.get_mut(&user_id).unwrap(); + user_connection_state.connection_ids.remove(&connection_id); + if user_connection_state.connection_ids.is_empty() { self.connections_by_user_id.remove(&user_id); } @@ -239,6 +250,7 @@ impl Store { self.connections_by_user_id .get(&user_id) .into_iter() + .map(|state| &state.connection_ids) .flatten() .copied() } @@ -248,6 +260,7 @@ impl Store { .connections_by_user_id .get(&user_id) .unwrap_or(&Default::default()) + .connection_ids .is_empty() } @@ -295,9 +308,10 @@ impl Store { } pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec { - let connection_ids = self.connections_by_user_id.get(&user_id); - let project_ids = connection_ids.iter().flat_map(|connection_ids| { - connection_ids + let user_connection_state = self.connections_by_user_id.get(&user_id); + let project_ids = user_connection_state.iter().flat_map(|state| { + state + .connection_ids .iter() .filter_map(|connection_id| self.connections.get(connection_id)) .flat_map(|connection| connection.projects.iter().copied()) @@ -333,8 +347,12 @@ impl Store { .connections .get_mut(&creator_connection_id) .ok_or_else(|| anyhow!("no such connection"))?; + let user_connection_state = self + .connections_by_user_id + .get_mut(&connection.user_id) + .ok_or_else(|| anyhow!("no such connection"))?; anyhow::ensure!( - connection.room.is_none(), + user_connection_state.room.is_none(), "cannot participate in more than one room at once" ); @@ -352,7 +370,7 @@ impl Store { let room_id = post_inc(&mut self.next_room_id); self.rooms.insert(room_id, room); - connection.room = Some(room_id); + user_connection_state.room = Some(RoomState::Joined); Ok(room_id) } @@ -365,14 +383,20 @@ impl Store { .connections .get_mut(&connection_id) .ok_or_else(|| anyhow!("no such connection"))?; + let user_id = connection.user_id; + let recipient_ids = self.connection_ids_for_user(user_id).collect::>(); + + let mut user_connection_state = self + .connections_by_user_id + .get_mut(&user_id) + .ok_or_else(|| anyhow!("no such connection"))?; anyhow::ensure!( - connection.room.is_none(), + user_connection_state + .room + .map_or(true, |room| room == RoomState::Calling { room_id }), "cannot participate in more than one room at once" ); - let user_id = connection.user_id; - let recipient_ids = self.connection_ids_for_user(user_id).collect::>(); - let room = self .rooms .get_mut(&room_id) @@ -393,6 +417,7 @@ impl Store { )), }), }); + user_connection_state.room = Some(RoomState::Joined); Ok((room, recipient_ids)) } @@ -404,7 +429,17 @@ impl Store { to_user_id: UserId, ) -> Result<(UserId, Vec, &proto::Room)> { let from_user_id = self.user_id_for_connection(from_connection_id)?; + let to_connection_ids = self.connection_ids_for_user(to_user_id).collect::>(); + let mut to_user_connection_state = self + .connections_by_user_id + .get_mut(&to_user_id) + .ok_or_else(|| anyhow!("no such connection"))?; + anyhow::ensure!( + to_user_connection_state.room.is_none(), + "recipient is already on another call" + ); + let room = self .rooms .get_mut(&room_id) @@ -422,11 +457,18 @@ impl Store { "cannot call the same user more than once" ); room.pending_calls_to_user_ids.push(to_user_id.to_proto()); + to_user_connection_state.room = Some(RoomState::Calling { room_id }); Ok((from_user_id, to_connection_ids, room)) } pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> { + let mut to_user_connection_state = self + .connections_by_user_id + .get_mut(&to_user_id) + .ok_or_else(|| anyhow!("no such connection"))?; + anyhow::ensure!(to_user_connection_state.room == Some(RoomState::Calling { room_id })); + to_user_connection_state.room = None; let room = self .rooms .get_mut(&room_id) @@ -548,10 +590,12 @@ impl Store { } for requester_user_id in project.join_requests.keys() { - if let Some(requester_connection_ids) = + if let Some(requester_user_connection_state) = self.connections_by_user_id.get_mut(requester_user_id) { - for requester_connection_id in requester_connection_ids.iter() { + for requester_connection_id in + &requester_user_connection_state.connection_ids + { if let Some(requester_connection) = self.connections.get_mut(requester_connection_id) { @@ -907,11 +951,12 @@ impl Store { .connections_by_user_id .get(&connection.user_id) .unwrap() + .connection_ids .contains(connection_id)); } - for (user_id, connection_ids) in &self.connections_by_user_id { - for connection_id in connection_ids { + for (user_id, state) in &self.connections_by_user_id { + for connection_id in &state.connection_ids { assert_eq!( self.connections.get(connection_id).unwrap().user_id, *user_id