Correctly leave projects when leaving room

Antonio Scandurra created

Change summary

crates/collab/src/db.rs  | 108 +++++++++++++++++++++++++++-------------
crates/collab/src/rpc.rs |  72 +++++++++++++--------------
2 files changed, 107 insertions(+), 73 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -1171,44 +1171,68 @@ where
                 .fetch_all(&mut tx)
                 .await?;
 
-                let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
+                let project_ids = sqlx::query_scalar::<_, ProjectId>(
                     "
-                    SELECT project_collaborators.*
-                    FROM projects, project_collaborators
-                    WHERE
-                        projects.room_id = $1 AND
-                        projects.id = project_collaborators.project_id AND
-                        project_collaborators.connection_id = $2
+                    SELECT project_id
+                    FROM project_collaborators
+                    WHERE connection_id = $1
                     ",
                 )
-                .bind(room_id)
                 .bind(connection_id.0 as i32)
-                .fetch(&mut tx);
+                .fetch_all(&mut tx)
+                .await?;
 
+                // Leave projects.
                 let mut left_projects = HashMap::default();
-                while let Some(collaborator) = project_collaborators.next().await {
-                    let collaborator = collaborator?;
-                    let left_project =
-                        left_projects
-                            .entry(collaborator.project_id)
-                            .or_insert(LeftProject {
-                                id: collaborator.project_id,
-                                host_user_id: Default::default(),
-                                connection_ids: Default::default(),
-                            });
-
-                    let collaborator_connection_id =
-                        ConnectionId(collaborator.connection_id as u32);
-                    if collaborator_connection_id != connection_id || collaborator.is_host {
-                        left_project.connection_ids.push(collaborator_connection_id);
+                if !project_ids.is_empty() {
+                    let mut params = "?,".repeat(project_ids.len());
+                    params.pop();
+                    let query = format!(
+                        "
+                        SELECT *
+                        FROM project_collaborators
+                        WHERE project_id IN ({params})
+                    "
+                    );
+                    let mut query = sqlx::query_as::<_, ProjectCollaborator>(&query);
+                    for project_id in project_ids {
+                        query = query.bind(project_id);
                     }
 
-                    if collaborator.is_host {
-                        left_project.host_user_id = collaborator.user_id;
+                    let mut project_collaborators = query.fetch(&mut tx);
+                    while let Some(collaborator) = project_collaborators.next().await {
+                        let collaborator = collaborator?;
+                        let left_project =
+                            left_projects
+                                .entry(collaborator.project_id)
+                                .or_insert(LeftProject {
+                                    id: collaborator.project_id,
+                                    host_user_id: Default::default(),
+                                    connection_ids: Default::default(),
+                                });
+
+                        let collaborator_connection_id =
+                            ConnectionId(collaborator.connection_id as u32);
+                        if collaborator_connection_id != connection_id {
+                            left_project.connection_ids.push(collaborator_connection_id);
+                        }
+
+                        if collaborator.is_host {
+                            left_project.host_user_id = collaborator.user_id;
+                        }
                     }
                 }
-                drop(project_collaborators);
+                sqlx::query(
+                    "
+                    DELETE FROM project_collaborators
+                    WHERE connection_id = $1
+                    ",
+                )
+                .bind(connection_id.0 as i32)
+                .execute(&mut tx)
+                .await?;
 
+                // Unshare projects.
                 sqlx::query(
                     "
                     DELETE FROM projects
@@ -1265,15 +1289,16 @@ where
             sqlx::query(
                 "
                 UPDATE room_participants
-                SET location_kind = $1 AND location_project_id = $2
+                SET location_kind = $1, location_project_id = $2
                 WHERE room_id = $3 AND answering_connection_id = $4
+                RETURNING 1
                 ",
             )
             .bind(location_kind)
             .bind(location_project_id)
             .bind(room_id)
             .bind(connection_id.0 as i32)
-            .execute(&mut tx)
+            .fetch_one(&mut tx)
             .await?;
 
             self.commit_room_transaction(room_id, tx).await
@@ -1335,21 +1360,32 @@ where
             let (
                 user_id,
                 answering_connection_id,
-                _location_kind,
-                _location_project_id,
+                location_kind,
+                location_project_id,
                 calling_user_id,
                 initial_project_id,
             ) = participant?;
             if let Some(answering_connection_id) = answering_connection_id {
+                let location = match (location_kind, location_project_id) {
+                    (Some(0), Some(project_id)) => {
+                        Some(proto::participant_location::Variant::SharedProject(
+                            proto::participant_location::SharedProject {
+                                id: project_id.to_proto(),
+                            },
+                        ))
+                    }
+                    (Some(1), _) => Some(proto::participant_location::Variant::UnsharedProject(
+                        Default::default(),
+                    )),
+                    _ => Some(proto::participant_location::Variant::External(
+                        Default::default(),
+                    )),
+                };
                 participants.push(proto::Participant {
                     user_id: user_id.to_proto(),
                     peer_id: answering_connection_id as u32,
                     projects: Default::default(),
-                    location: Some(proto::ParticipantLocation {
-                        variant: Some(proto::participant_location::Variant::External(
-                            Default::default(),
-                        )),
-                    }),
+                    location: Some(proto::ParticipantLocation { variant: location }),
                 });
             } else {
                 pending_participants.push(proto::PendingParticipant {

crates/collab/src/rpc.rs 🔗

@@ -624,19 +624,19 @@ impl Server {
 
     async fn leave_room_for_connection(
         self: &Arc<Server>,
-        connection_id: ConnectionId,
-        user_id: UserId,
+        leaving_connection_id: ConnectionId,
+        leaving_user_id: UserId,
     ) -> Result<()> {
         let mut contacts_to_update = HashSet::default();
 
-        let Some(left_room) = self.app_state.db.leave_room_for_connection(connection_id).await? else {
+        let Some(left_room) = self.app_state.db.leave_room_for_connection(leaving_connection_id).await? else {
             return Err(anyhow!("no room to leave"))?;
         };
-        contacts_to_update.insert(user_id);
+        contacts_to_update.insert(leaving_user_id);
 
         for project in left_room.left_projects.into_values() {
-            if project.host_user_id == user_id {
-                for connection_id in project.connection_ids {
+            for connection_id in project.connection_ids {
+                if project.host_user_id == leaving_user_id {
                     self.peer
                         .send(
                             connection_id,
@@ -645,29 +645,27 @@ impl Server {
                             },
                         )
                         .trace_err();
-                }
-            } else {
-                for connection_id in project.connection_ids {
+                } else {
                     self.peer
                         .send(
                             connection_id,
                             proto::RemoveProjectCollaborator {
                                 project_id: project.id.to_proto(),
-                                peer_id: connection_id.0,
+                                peer_id: leaving_connection_id.0,
                             },
                         )
                         .trace_err();
                 }
-
-                self.peer
-                    .send(
-                        connection_id,
-                        proto::UnshareProject {
-                            project_id: project.id.to_proto(),
-                        },
-                    )
-                    .trace_err();
             }
+
+            self.peer
+                .send(
+                    leaving_connection_id,
+                    proto::UnshareProject {
+                        project_id: project.id.to_proto(),
+                    },
+                )
+                .trace_err();
         }
 
         self.room_updated(&left_room.room);
@@ -691,7 +689,7 @@ impl Server {
             live_kit
                 .remove_participant(
                     left_room.room.live_kit_room.clone(),
-                    connection_id.to_string(),
+                    leaving_connection_id.to_string(),
                 )
                 .await
                 .trace_err();
@@ -941,6 +939,9 @@ impl Server {
         let collaborators = project
             .collaborators
             .iter()
+            .filter(|collaborator| {
+                collaborator.connection_id != request.sender_connection_id.0 as i32
+            })
             .map(|collaborator| proto::Collaborator {
                 peer_id: collaborator.connection_id as u32,
                 replica_id: collaborator.replica_id.0 as u32,
@@ -958,23 +959,20 @@ impl Server {
             })
             .collect::<Vec<_>>();
 
-        for collaborator in &project.collaborators {
-            let connection_id = ConnectionId(collaborator.connection_id as u32);
-            if connection_id != request.sender_connection_id {
-                self.peer
-                    .send(
-                        connection_id,
-                        proto::AddProjectCollaborator {
-                            project_id: project_id.to_proto(),
-                            collaborator: Some(proto::Collaborator {
-                                peer_id: request.sender_connection_id.0,
-                                replica_id: replica_id.0 as u32,
-                                user_id: guest_user_id.to_proto(),
-                            }),
-                        },
-                    )
-                    .trace_err();
-            }
+        for collaborator in &collaborators {
+            self.peer
+                .send(
+                    ConnectionId(collaborator.peer_id),
+                    proto::AddProjectCollaborator {
+                        project_id: project_id.to_proto(),
+                        collaborator: Some(proto::Collaborator {
+                            peer_id: request.sender_connection_id.0,
+                            replica_id: replica_id.0 as u32,
+                            user_id: guest_user_id.to_proto(),
+                        }),
+                    },
+                )
+                .trace_err();
         }
 
         // First, we send the metadata associated with each worktree.