Leave room when connection is dropped

Antonio Scandurra created

Change summary

crates/call/src/room.rs                                          |   4 
crates/collab/migrations.sqlite/20221109000000_test_schema.sql   |   5 
crates/collab/migrations/20221111092550_reconnection_support.sql |   3 
crates/collab/src/db.rs                                          | 189 
crates/collab/src/rpc.rs                                         | 202 -
crates/collab/src/rpc/store.rs                                   |  10 
crates/rpc/proto/zed.proto                                       |   4 
7 files changed, 183 insertions(+), 234 deletions(-)

Detailed changes

crates/call/src/room.rs 🔗

@@ -53,7 +53,7 @@ impl Entity for Room {
 
     fn release(&mut self, _: &mut MutableAppContext) {
         if self.status.is_online() {
-            self.client.send(proto::LeaveRoom { id: self.id }).log_err();
+            self.client.send(proto::LeaveRoom {}).log_err();
         }
     }
 }
@@ -241,7 +241,7 @@ impl Room {
         self.participant_user_ids.clear();
         self.subscriptions.clear();
         self.live_kit.take();
-        self.client.send(proto::LeaveRoom { id: self.id })?;
+        self.client.send(proto::LeaveRoom {})?;
         Ok(())
     }
 

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -43,7 +43,8 @@ CREATE TABLE "rooms" (
 CREATE TABLE "projects" (
     "id" INTEGER PRIMARY KEY,
     "room_id" INTEGER REFERENCES rooms (id),
-    "host_user_id" INTEGER REFERENCES users (id) NOT NULL
+    "host_user_id" INTEGER REFERENCES users (id) NOT NULL,
+    "host_connection_id" INTEGER NOT NULL
 );
 
 CREATE TABLE "project_collaborators" (
@@ -72,7 +73,7 @@ CREATE TABLE "room_participants" (
     "location_kind" INTEGER,
     "location_project_id" INTEGER REFERENCES projects (id),
     "initial_project_id" INTEGER REFERENCES projects (id),
-    "calling_user_id" INTEGER NOT NULL REFERENCES users (id)
+    "calling_user_id" INTEGER NOT NULL REFERENCES users (id),
     "calling_connection_id" INTEGER NOT NULL
 );
 CREATE UNIQUE INDEX "index_room_participants_on_user_id" ON "room_participants" ("user_id");

crates/collab/migrations/20221111092550_reconnection_support.sql 🔗

@@ -6,6 +6,7 @@ CREATE TABLE IF NOT EXISTS "rooms" (
 
 ALTER TABLE "projects"
     ADD "room_id" INTEGER REFERENCES rooms (id),
+    ADD "host_connection_id" INTEGER,
     DROP COLUMN "unregistered";
 
 CREATE TABLE "project_collaborators" (
@@ -30,7 +31,7 @@ CREATE TABLE IF NOT EXISTS "room_participants" (
     "id" SERIAL PRIMARY KEY,
     "room_id" INTEGER NOT NULL REFERENCES rooms (id),
     "user_id" INTEGER NOT NULL REFERENCES users (id),
-    "connection_id" INTEGER,
+    "answering_connection_id" INTEGER,
     "location_kind" INTEGER,
     "location_project_id" INTEGER REFERENCES projects (id),
     "initial_project_id" INTEGER REFERENCES projects (id),

crates/collab/src/db.rs 🔗

@@ -907,14 +907,15 @@ where
 
             sqlx::query(
                 "
-                INSERT INTO room_participants (room_id, user_id, connection_id, calling_user_id)
-                VALUES ($1, $2, $3, $4)
+                INSERT INTO room_participants (room_id, user_id, answering_connection_id, calling_user_id, calling_connection_id)
+                VALUES ($1, $2, $3, $4, $5)
                 ",
             )
             .bind(room_id)
             .bind(user_id)
             .bind(connection_id.0 as i32)
             .bind(user_id)
+            .bind(connection_id.0 as i32)
             .execute(&mut tx)
             .await?;
 
@@ -926,6 +927,7 @@ where
         &self,
         room_id: RoomId,
         calling_user_id: UserId,
+        calling_connection_id: ConnectionId,
         called_user_id: UserId,
         initial_project_id: Option<ProjectId>,
     ) -> Result<(proto::Room, proto::IncomingCall)> {
@@ -933,13 +935,14 @@ where
             let mut tx = self.pool.begin().await?;
             sqlx::query(
                 "
-                INSERT INTO room_participants (room_id, user_id, calling_user_id, initial_project_id)
-                VALUES ($1, $2, $3, $4)
+                INSERT INTO room_participants (room_id, user_id, calling_user_id, calling_connection_id, initial_project_id)
+                VALUES ($1, $2, $3, $4, $5)
                 ",
             )
             .bind(room_id)
             .bind(called_user_id)
             .bind(calling_user_id)
+            .bind(calling_connection_id.0 as i32)
             .bind(initial_project_id)
             .execute(&mut tx)
             .await?;
@@ -961,7 +964,7 @@ where
                 "
                 SELECT room_id
                 FROM room_participants
-                WHERE user_id = $1 AND connection_id IS NULL
+                WHERE user_id = $1 AND answering_connection_id IS NULL
                 ",
             )
             .bind(user_id)
@@ -1033,7 +1036,7 @@ where
             sqlx::query(
                 "
                 DELETE FROM room_participants
-                WHERE room_id = $1 AND user_id = $2 AND connection_id IS NULL
+                WHERE room_id = $1 AND user_id = $2 AND answering_connection_id IS NULL
                 ",
             )
             .bind(room_id)
@@ -1056,7 +1059,7 @@ where
             sqlx::query(
                 "
                 UPDATE room_participants 
-                SET connection_id = $1
+                SET answering_connection_id = $1
                 WHERE room_id = $2 AND user_id = $3
                 RETURNING 1
                 ",
@@ -1070,101 +1073,100 @@ where
         })
     }
 
-    pub async fn leave_room(
-        &self,
-        room_id: RoomId,
-        connection_id: ConnectionId,
-    ) -> Result<LeftRoom> {
+    pub async fn leave_room(&self, connection_id: ConnectionId) -> Result<Option<LeftRoom>> {
         test_support!(self, {
             let mut tx = self.pool.begin().await?;
 
             // Leave room.
-            let user_id: UserId = sqlx::query_scalar(
-                "
-                DELETE FROM room_participants
-                WHERE room_id = $1 AND connection_id = $2
-                RETURNING user_id
-                ",
-            )
-            .bind(room_id)
-            .bind(connection_id.0 as i32)
-            .fetch_one(&mut tx)
-            .await?;
-
-            // Cancel pending calls initiated by the leaving user.
-            let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
+            let room_id = sqlx::query_scalar::<_, RoomId>(
                 "
                 DELETE FROM room_participants
-                WHERE calling_user_id = $1 AND connection_id IS NULL
-                RETURNING user_id
+                WHERE answering_connection_id = $1
+                RETURNING room_id
                 ",
             )
-            .bind(room_id)
             .bind(connection_id.0 as i32)
-            .fetch_all(&mut tx)
+            .fetch_optional(&mut tx)
             .await?;
 
-            let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
-                "
-                SELECT project_collaborators.*
-                FROM projects, project_collaborators
-                WHERE
-                    projects.room_id = $1 AND
-                    projects.host_user_id = $2 AND
-                    projects.id = project_collaborators.project_id
-                ",
-            )
-            .bind(room_id)
-            .bind(user_id)
-            .fetch(&mut tx);
-
-            let mut left_projects = HashMap::default();
-            while let Some(collaborator) = project_collaborators.next().await {
-                let collaborator = collaborator?;
-                let left_project =
-                    left_projects
-                        .entry(collaborator.project_id)
-                        .or_insert(LeftProject {
-                            id: collaborator.project_id,
-                            host_user_id: Default::default(),
-                            connection_ids: Default::default(),
-                        });
+            if let Some(room_id) = room_id {
+                // Cancel pending calls initiated by the leaving user.
+                let canceled_calls_to_user_ids: Vec<UserId> = sqlx::query_scalar(
+                    "
+                    DELETE FROM room_participants
+                    WHERE calling_connection_id = $1 AND answering_connection_id IS NULL
+                    RETURNING user_id
+                    ",
+                )
+                .bind(connection_id.0 as i32)
+                .fetch_all(&mut tx)
+                .await?;
 
-                let collaborator_connection_id = ConnectionId(collaborator.connection_id as u32);
-                if collaborator_connection_id != connection_id || collaborator.is_host {
-                    left_project.connection_ids.push(collaborator_connection_id);
-                }
+                let mut project_collaborators = sqlx::query_as::<_, ProjectCollaborator>(
+                    "
+                    SELECT project_collaborators.*
+                    FROM projects, project_collaborators
+                    WHERE
+                        projects.room_id = $1 AND
+                        projects.host_connection_id = $2 AND
+                        projects.id = project_collaborators.project_id
+                    ",
+                )
+                .bind(room_id)
+                .bind(connection_id.0 as i32)
+                .fetch(&mut tx);
+
+                let mut left_projects = HashMap::default();
+                while let Some(collaborator) = project_collaborators.next().await {
+                    let collaborator = collaborator?;
+                    let left_project =
+                        left_projects
+                            .entry(collaborator.project_id)
+                            .or_insert(LeftProject {
+                                id: collaborator.project_id,
+                                host_user_id: Default::default(),
+                                connection_ids: Default::default(),
+                            });
+
+                    let collaborator_connection_id =
+                        ConnectionId(collaborator.connection_id as u32);
+                    if collaborator_connection_id != connection_id || collaborator.is_host {
+                        left_project.connection_ids.push(collaborator_connection_id);
+                    }
 
-                if collaborator.is_host {
-                    left_project.host_user_id = collaborator.user_id;
+                    if collaborator.is_host {
+                        left_project.host_user_id = collaborator.user_id;
+                    }
                 }
-            }
-            drop(project_collaborators);
+                drop(project_collaborators);
 
-            sqlx::query(
-                "
-                DELETE FROM projects
-                WHERE room_id = $1 AND host_user_id = $2
-                ",
-            )
-            .bind(room_id)
-            .bind(user_id)
-            .execute(&mut tx)
-            .await?;
+                sqlx::query(
+                    "
+                    DELETE FROM projects
+                    WHERE room_id = $1 AND host_connection_id = $2
+                    ",
+                )
+                .bind(room_id)
+                .bind(connection_id.0 as i32)
+                .execute(&mut tx)
+                .await?;
 
-            let room = self.commit_room_transaction(room_id, tx).await?;
-            Ok(LeftRoom {
-                room,
-                left_projects,
-                canceled_calls_to_user_ids,
-            })
+                let room = self.commit_room_transaction(room_id, tx).await?;
+                Ok(Some(LeftRoom {
+                    room,
+                    left_projects,
+                    canceled_calls_to_user_ids,
+                }))
+            } else {
+                Ok(None)
+            }
         })
     }
 
     pub async fn update_room_participant_location(
         &self,
         room_id: RoomId,
-        user_id: UserId,
+        connection_id: ConnectionId,
         location: proto::ParticipantLocation,
     ) -> Result<proto::Room> {
         test_support!(self, {
@@ -1194,13 +1196,13 @@ where
                 "
                 UPDATE room_participants
                 SET location_kind = $1 AND location_project_id = $2
-                WHERE room_id = $1 AND user_id = $2
+                WHERE room_id = $3 AND answering_connection_id = $4
                 ",
             )
             .bind(location_kind)
             .bind(location_project_id)
             .bind(room_id)
-            .bind(user_id)
+            .bind(connection_id.0 as i32)
             .execute(&mut tx)
             .await?;
 
@@ -1248,7 +1250,7 @@ where
         let mut db_participants =
             sqlx::query_as::<_, (UserId, Option<i32>, Option<i32>, Option<ProjectId>, UserId, Option<ProjectId>)>(
                 "
-                SELECT user_id, connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
+                SELECT user_id, answering_connection_id, location_kind, location_project_id, calling_user_id, initial_project_id
                 FROM room_participants
                 WHERE room_id = $1
                 ",
@@ -1261,16 +1263,16 @@ where
         while let Some(participant) = db_participants.next().await {
             let (
                 user_id,
-                connection_id,
+                answering_connection_id,
                 _location_kind,
                 _location_project_id,
                 calling_user_id,
                 initial_project_id,
             ) = participant?;
-            if let Some(connection_id) = connection_id {
+            if let Some(answering_connection_id) = answering_connection_id {
                 participants.push(proto::Participant {
                     user_id: user_id.to_proto(),
-                    peer_id: connection_id as u32,
+                    peer_id: answering_connection_id as u32,
                     projects: Default::default(),
                     location: Some(proto::ParticipantLocation {
                         variant: Some(proto::participant_location::Variant::External(
@@ -1339,12 +1341,13 @@ where
             let mut tx = self.pool.begin().await?;
             let project_id = sqlx::query_scalar(
                 "
-                INSERT INTO projects (host_user_id, room_id)
-                VALUES ($1)
+                INSERT INTO projects (host_user_id, host_connection_id, room_id)
+                VALUES ($1, $2, $3)
                 RETURNING id
                 ",
             )
             .bind(user_id)
+            .bind(connection_id.0 as i32)
             .bind(room_id)
             .fetch_one(&mut tx)
             .await
@@ -1366,11 +1369,11 @@ where
             sqlx::query(
                 "
                 INSERT INTO project_collaborators (
-                project_id,
-                connection_id,
-                user_id,
-                replica_id,
-                is_host
+                    project_id,
+                    connection_id,
+                    user_id,
+                    replica_id,
+                    is_host
                 )
                 VALUES ($1, $2, $3, $4, $5)
                 ",

crates/collab/src/rpc.rs 🔗

@@ -415,7 +415,7 @@ impl Server {
 
             drop(foreground_message_handlers);
             tracing::info!(%user_id, %login, %connection_id, %address, "signing out");
-            if let Err(error) = this.sign_out(connection_id).await {
+            if let Err(error) = this.sign_out(connection_id, user_id).await {
                 tracing::error!(%user_id, %login, %connection_id, %address, ?error, "error signing out");
             }
 
@@ -424,69 +424,15 @@ impl Server {
     }
 
     #[instrument(skip(self), err)]
-    async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> Result<()> {
+    async fn sign_out(
+        self: &mut Arc<Self>,
+        connection_id: ConnectionId,
+        user_id: UserId,
+    ) -> Result<()> {
         self.peer.disconnect(connection_id);
-
-        let mut projects_to_unshare = Vec::new();
-        let mut contacts_to_update = HashSet::default();
-        let mut room_left = None;
-        {
-            let removed_connection = self.store().await.remove_connection(connection_id)?;
-            self.app_state.db.remove_connection(connection_id);
-
-            for project in removed_connection.hosted_projects {
-                projects_to_unshare.push(project.id);
-                broadcast(connection_id, project.guests.keys().copied(), |conn_id| {
-                    self.peer.send(
-                        conn_id,
-                        proto::UnshareProject {
-                            project_id: project.id.to_proto(),
-                        },
-                    )
-                });
-            }
-
-            for project in removed_connection.guest_projects {
-                broadcast(connection_id, project.connection_ids, |conn_id| {
-                    self.peer.send(
-                        conn_id,
-                        proto::RemoveProjectCollaborator {
-                            project_id: project.id.to_proto(),
-                            peer_id: connection_id.0,
-                        },
-                    )
-                });
-            }
-
-            if let Some(room) = removed_connection.room {
-                self.room_updated(&room);
-                room_left = Some(self.room_left(&room, connection_id));
-            }
-
-            contacts_to_update.insert(removed_connection.user_id);
-            for connection_id in removed_connection.canceled_call_connection_ids {
-                self.peer
-                    .send(connection_id, proto::CallCanceled {})
-                    .trace_err();
-                contacts_to_update.extend(store.user_id_for_connection(connection_id).ok());
-            }
-        };
-
-        if let Some(room_left) = room_left {
-            room_left.await.trace_err();
-        }
-
-        for user_id in contacts_to_update {
-            self.update_user_contacts(user_id).await.trace_err();
-        }
-
-        for project_id in projects_to_unshare {
-            self.app_state
-                .db
-                .unshare_project(project_id)
-                .await
-                .trace_err();
-        }
+        self.store().await.remove_connection(connection_id)?;
+        self.leave_room_for_connection(connection_id, user_id)
+            .await?;
 
         Ok(())
     }
@@ -653,66 +599,90 @@ impl Server {
     }
 
     async fn leave_room(self: Arc<Server>, message: Message<proto::LeaveRoom>) -> Result<()> {
+        self.leave_room_for_connection(message.sender_connection_id, message.sender_user_id)
+            .await
+    }
+
+    async fn leave_room_for_connection(
+        self: &Arc<Server>,
+        connection_id: ConnectionId,
+        user_id: UserId,
+    ) -> Result<()> {
         let mut contacts_to_update = HashSet::default();
 
-        let left_room = self
-            .app_state
-            .db
-            .leave_room(
-                RoomId::from_proto(message.payload.id),
-                message.sender_connection_id,
-            )
-            .await?;
-        contacts_to_update.insert(message.sender_user_id);
+        let Some(left_room) = self.app_state.db.leave_room(connection_id).await? else {
+            return Err(anyhow!("no room to leave"))?;
+        };
+        contacts_to_update.insert(user_id);
 
         for project in left_room.left_projects.into_values() {
-            if project.host_user_id == message.sender_user_id {
+            if project.host_user_id == user_id {
                 for connection_id in project.connection_ids {
-                    self.peer.send(
-                        connection_id,
-                        proto::UnshareProject {
-                            project_id: project.id.to_proto(),
-                        },
-                    )?;
+                    self.peer
+                        .send(
+                            connection_id,
+                            proto::UnshareProject {
+                                project_id: project.id.to_proto(),
+                            },
+                        )
+                        .trace_err();
                 }
             } else {
                 for connection_id in project.connection_ids {
-                    self.peer.send(
+                    self.peer
+                        .send(
+                            connection_id,
+                            proto::RemoveProjectCollaborator {
+                                project_id: project.id.to_proto(),
+                                peer_id: connection_id.0,
+                            },
+                        )
+                        .trace_err();
+                }
+
+                self.peer
+                    .send(
                         connection_id,
-                        proto::RemoveProjectCollaborator {
+                        proto::UnshareProject {
                             project_id: project.id.to_proto(),
-                            peer_id: message.sender_connection_id.0,
                         },
-                    )?;
-                }
-
-                self.peer.send(
-                    message.sender_connection_id,
-                    proto::UnshareProject {
-                        project_id: project.id.to_proto(),
-                    },
-                )?;
+                    )
+                    .trace_err();
             }
         }
 
         self.room_updated(&left_room.room);
         {
             let store = self.store().await;
-            for user_id in left_room.canceled_calls_to_user_ids {
-                for connection_id in store.connection_ids_for_user(user_id) {
+            for canceled_user_id in left_room.canceled_calls_to_user_ids {
+                for connection_id in store.connection_ids_for_user(canceled_user_id) {
                     self.peer
                         .send(connection_id, proto::CallCanceled {})
                         .trace_err();
                 }
-                contacts_to_update.insert(user_id);
+                contacts_to_update.insert(canceled_user_id);
             }
         }
 
-        self.room_left(&left_room.room, message.sender_connection_id)
-            .await
-            .trace_err();
-        for user_id in contacts_to_update {
-            self.update_user_contacts(user_id).await?;
+        for contact_user_id in contacts_to_update {
+            self.update_user_contacts(contact_user_id).await?;
+        }
+
+        if let Some(live_kit) = self.app_state.live_kit_client.as_ref() {
+            live_kit
+                .remove_participant(
+                    left_room.room.live_kit_room.clone(),
+                    connection_id.to_string(),
+                )
+                .await
+                .trace_err();
+
+            if left_room.room.participants.is_empty() {
+                live_kit
+                    .delete_room(left_room.room.live_kit_room)
+                    .await
+                    .trace_err();
+            }
         }
 
         Ok(())
@@ -725,6 +695,7 @@ impl Server {
     ) -> Result<()> {
         let room_id = RoomId::from_proto(request.payload.room_id);
         let calling_user_id = request.sender_user_id;
+        let calling_connection_id = request.sender_connection_id;
         let called_user_id = UserId::from_proto(request.payload.called_user_id);
         let initial_project_id = request
             .payload
@@ -742,7 +713,13 @@ impl Server {
         let (room, incoming_call) = self
             .app_state
             .db
-            .call(room_id, calling_user_id, called_user_id, initial_project_id)
+            .call(
+                room_id,
+                calling_user_id,
+                calling_connection_id,
+                called_user_id,
+                initial_project_id,
+            )
             .await?;
         self.room_updated(&room);
         self.update_user_contacts(called_user_id).await?;
@@ -838,7 +815,7 @@ impl Server {
         let room = self
             .app_state
             .db
-            .update_room_participant_location(room_id, request.sender_user_id, location)
+            .update_room_participant_location(room_id, request.sender_connection_id, location)
             .await?;
         self.room_updated(&room);
         response.send(proto::Ack {})?;
@@ -858,29 +835,6 @@ impl Server {
         }
     }
 
-    fn room_left(
-        &self,
-        room: &proto::Room,
-        connection_id: ConnectionId,
-    ) -> impl Future<Output = Result<()>> {
-        let client = self.app_state.live_kit_client.clone();
-        let room_name = room.live_kit_room.clone();
-        let participant_count = room.participants.len();
-        async move {
-            if let Some(client) = client {
-                client
-                    .remove_participant(room_name.clone(), connection_id.to_string())
-                    .await?;
-
-                if participant_count == 0 {
-                    client.delete_room(room_name).await?;
-                }
-            }
-
-            Ok(())
-        }
-    }
-
     async fn share_project(
         self: Arc<Server>,
         request: Message<proto::ShareProject>,

crates/collab/src/rpc/store.rs 🔗

@@ -3,7 +3,7 @@ use anyhow::{anyhow, Result};
 use collections::{btree_map, BTreeMap, BTreeSet, HashMap, HashSet};
 use rpc::{proto, ConnectionId};
 use serde::Serialize;
-use std::{borrow::Cow, mem, path::PathBuf, str};
+use std::{mem, path::PathBuf, str};
 use tracing::instrument;
 
 pub type RoomId = u64;
@@ -135,14 +135,6 @@ impl Store {
         Ok(())
     }
 
-    pub fn user_id_for_connection(&self, connection_id: ConnectionId) -> Result<UserId> {
-        Ok(self
-            .connections
-            .get(&connection_id)
-            .ok_or_else(|| anyhow!("unknown connection"))?
-            .user_id)
-    }
-
     pub fn connection_ids_for_user(
         &self,
         user_id: UserId,

crates/rpc/proto/zed.proto 🔗

@@ -158,9 +158,7 @@ message JoinRoomResponse {
     optional LiveKitConnectionInfo live_kit_connection_info = 2;
 }
 
-message LeaveRoom {
-    uint64 id = 1;
-}
+message LeaveRoom {}
 
 message Room {
     uint64 id = 1;