Prevent the same user from being called more than once

Antonio Scandurra created

Change summary

crates/collab/src/rpc/store.rs | 83 +++++++++++++++++++++++++++--------
1 file changed, 64 insertions(+), 19 deletions(-)

Detailed changes

crates/collab/src/rpc/store.rs 🔗

@@ -13,7 +13,7 @@ pub type RoomId = u64;
 #[derive(Default, Serialize)]
 pub struct Store {
     connections: BTreeMap<ConnectionId, ConnectionState>,
-    connections_by_user_id: BTreeMap<UserId, HashSet<ConnectionId>>,
+    connections_by_user_id: BTreeMap<UserId, UserConnectionState>,
     next_room_id: RoomId,
     rooms: BTreeMap<RoomId, proto::Room>,
     projects: BTreeMap<ProjectId, Project>,
@@ -21,16 +21,27 @@ pub struct Store {
     channels: BTreeMap<ChannelId, Channel>,
 }
 
+#[derive(Default, Serialize)]
+struct UserConnectionState {
+    connection_ids: HashSet<ConnectionId>,
+    room: Option<RoomState>,
+}
+
 #[derive(Serialize)]
 struct ConnectionState {
     user_id: UserId,
     admin: bool,
-    room: Option<RoomId>,
     projects: BTreeSet<ProjectId>,
     requested_projects: HashSet<ProjectId>,
     channels: HashSet<ChannelId>,
 }
 
+#[derive(Copy, Clone, Eq, PartialEq, Serialize)]
+enum RoomState {
+    Joined,
+    Calling { room_id: RoomId },
+}
+
 #[derive(Serialize)]
 pub struct Project {
     pub online: bool,
@@ -140,7 +151,6 @@ impl Store {
             ConnectionState {
                 user_id,
                 admin,
-                room: Default::default(),
                 projects: Default::default(),
                 requested_projects: Default::default(),
                 channels: Default::default(),
@@ -149,6 +159,7 @@ impl Store {
         self.connections_by_user_id
             .entry(user_id)
             .or_default()
+            .connection_ids
             .insert(connection_id);
     }
 
@@ -185,9 +196,9 @@ impl Store {
             }
         }
 
-        let user_connections = self.connections_by_user_id.get_mut(&user_id).unwrap();
-        user_connections.remove(&connection_id);
-        if user_connections.is_empty() {
+        let user_connection_state = self.connections_by_user_id.get_mut(&user_id).unwrap();
+        user_connection_state.connection_ids.remove(&connection_id);
+        if user_connection_state.connection_ids.is_empty() {
             self.connections_by_user_id.remove(&user_id);
         }
 
@@ -239,6 +250,7 @@ impl Store {
         self.connections_by_user_id
             .get(&user_id)
             .into_iter()
+            .map(|state| &state.connection_ids)
             .flatten()
             .copied()
     }
@@ -248,6 +260,7 @@ impl Store {
             .connections_by_user_id
             .get(&user_id)
             .unwrap_or(&Default::default())
+            .connection_ids
             .is_empty()
     }
 
@@ -295,9 +308,10 @@ impl Store {
     }
 
     pub fn project_metadata_for_user(&self, user_id: UserId) -> Vec<proto::ProjectMetadata> {
-        let connection_ids = self.connections_by_user_id.get(&user_id);
-        let project_ids = connection_ids.iter().flat_map(|connection_ids| {
-            connection_ids
+        let user_connection_state = self.connections_by_user_id.get(&user_id);
+        let project_ids = user_connection_state.iter().flat_map(|state| {
+            state
+                .connection_ids
                 .iter()
                 .filter_map(|connection_id| self.connections.get(connection_id))
                 .flat_map(|connection| connection.projects.iter().copied())
@@ -333,8 +347,12 @@ impl Store {
             .connections
             .get_mut(&creator_connection_id)
             .ok_or_else(|| anyhow!("no such connection"))?;
+        let user_connection_state = self
+            .connections_by_user_id
+            .get_mut(&connection.user_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
         anyhow::ensure!(
-            connection.room.is_none(),
+            user_connection_state.room.is_none(),
             "cannot participate in more than one room at once"
         );
 
@@ -352,7 +370,7 @@ impl Store {
 
         let room_id = post_inc(&mut self.next_room_id);
         self.rooms.insert(room_id, room);
-        connection.room = Some(room_id);
+        user_connection_state.room = Some(RoomState::Joined);
         Ok(room_id)
     }
 
@@ -365,14 +383,20 @@ impl Store {
             .connections
             .get_mut(&connection_id)
             .ok_or_else(|| anyhow!("no such connection"))?;
+        let user_id = connection.user_id;
+        let recipient_ids = self.connection_ids_for_user(user_id).collect::<Vec<_>>();
+
+        let mut user_connection_state = self
+            .connections_by_user_id
+            .get_mut(&user_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
         anyhow::ensure!(
-            connection.room.is_none(),
+            user_connection_state
+                .room
+                .map_or(true, |room| room == RoomState::Calling { room_id }),
             "cannot participate in more than one room at once"
         );
 
-        let user_id = connection.user_id;
-        let recipient_ids = self.connection_ids_for_user(user_id).collect::<Vec<_>>();
-
         let room = self
             .rooms
             .get_mut(&room_id)
@@ -393,6 +417,7 @@ impl Store {
                 )),
             }),
         });
+        user_connection_state.room = Some(RoomState::Joined);
 
         Ok((room, recipient_ids))
     }
@@ -404,7 +429,17 @@ impl Store {
         to_user_id: UserId,
     ) -> Result<(UserId, Vec<ConnectionId>, &proto::Room)> {
         let from_user_id = self.user_id_for_connection(from_connection_id)?;
+
         let to_connection_ids = self.connection_ids_for_user(to_user_id).collect::<Vec<_>>();
+        let mut to_user_connection_state = self
+            .connections_by_user_id
+            .get_mut(&to_user_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
+        anyhow::ensure!(
+            to_user_connection_state.room.is_none(),
+            "recipient is already on another call"
+        );
+
         let room = self
             .rooms
             .get_mut(&room_id)
@@ -422,11 +457,18 @@ impl Store {
             "cannot call the same user more than once"
         );
         room.pending_calls_to_user_ids.push(to_user_id.to_proto());
+        to_user_connection_state.room = Some(RoomState::Calling { room_id });
 
         Ok((from_user_id, to_connection_ids, room))
     }
 
     pub fn call_failed(&mut self, room_id: RoomId, to_user_id: UserId) -> Result<&proto::Room> {
+        let mut to_user_connection_state = self
+            .connections_by_user_id
+            .get_mut(&to_user_id)
+            .ok_or_else(|| anyhow!("no such connection"))?;
+        anyhow::ensure!(to_user_connection_state.room == Some(RoomState::Calling { room_id }));
+        to_user_connection_state.room = None;
         let room = self
             .rooms
             .get_mut(&room_id)
@@ -548,10 +590,12 @@ impl Store {
                     }
 
                     for requester_user_id in project.join_requests.keys() {
-                        if let Some(requester_connection_ids) =
+                        if let Some(requester_user_connection_state) =
                             self.connections_by_user_id.get_mut(requester_user_id)
                         {
-                            for requester_connection_id in requester_connection_ids.iter() {
+                            for requester_connection_id in
+                                &requester_user_connection_state.connection_ids
+                            {
                                 if let Some(requester_connection) =
                                     self.connections.get_mut(requester_connection_id)
                                 {
@@ -907,11 +951,12 @@ impl Store {
                 .connections_by_user_id
                 .get(&connection.user_id)
                 .unwrap()
+                .connection_ids
                 .contains(connection_id));
         }
 
-        for (user_id, connection_ids) in &self.connections_by_user_id {
-            for connection_id in connection_ids {
+        for (user_id, state) in &self.connections_by_user_id {
+            for connection_id in &state.connection_ids {
                 assert_eq!(
                     self.connections.get(connection_id).unwrap().user_id,
                     *user_id