Make project id optional when following - server only

Max Brunsfeld created

Change summary

crates/collab/migrations.sqlite/20221109000000_test_schema.sql              |  2 
crates/collab/migrations/20230918142700_allow_following_without_project.sql |  1 
crates/collab/src/db/queries/projects.rs                                    | 95 
crates/collab/src/db/queries/rooms.rs                                       | 61 
crates/collab/src/db/tables/follower.rs                                     |  2 
crates/collab/src/db/tables/room_participant.rs                             | 10 
crates/collab/src/rpc.rs                                                    | 60 
crates/rpc/proto/zed.proto                                                  | 23 
crates/rpc/src/proto.rs                                                     |  3 
crates/rpc/src/rpc.rs                                                       |  2 
10 files changed, 204 insertions(+), 55 deletions(-)

Detailed changes

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -175,7 +175,7 @@ CREATE TABLE "servers" (
 CREATE TABLE "followers" (
     "id" INTEGER PRIMARY KEY AUTOINCREMENT,
     "room_id" INTEGER NOT NULL REFERENCES rooms (id) ON DELETE CASCADE,
-    "project_id" INTEGER NOT NULL REFERENCES projects (id) ON DELETE CASCADE,
+    "project_id" INTEGER REFERENCES projects (id) ON DELETE CASCADE,
     "leader_connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE,
     "leader_connection_id" INTEGER NOT NULL,
     "follower_connection_server_id" INTEGER NOT NULL REFERENCES servers (id) ON DELETE CASCADE,

crates/collab/src/db/queries/projects.rs 🔗

@@ -738,7 +738,7 @@ impl Database {
                     Condition::any()
                         .add(
                             Condition::all()
-                                .add(follower::Column::ProjectId.eq(project_id))
+                                .add(follower::Column::ProjectId.eq(Some(project_id)))
                                 .add(
                                     follower::Column::LeaderConnectionServerId
                                         .eq(connection.owner_id),
@@ -747,7 +747,7 @@ impl Database {
                         )
                         .add(
                             Condition::all()
-                                .add(follower::Column::ProjectId.eq(project_id))
+                                .add(follower::Column::ProjectId.eq(Some(project_id)))
                                 .add(
                                     follower::Column::FollowerConnectionServerId
                                         .eq(connection.owner_id),
@@ -862,13 +862,95 @@ impl Database {
         .await
     }
 
+    pub async fn check_can_follow(
+        &self,
+        room_id: RoomId,
+        project_id: Option<ProjectId>,
+        leader_id: ConnectionId,
+        follower_id: ConnectionId,
+    ) -> Result<()> {
+        let mut found_leader = false;
+        let mut found_follower = false;
+        self.transaction(|tx| async move {
+            if let Some(project_id) = project_id {
+                let mut rows = project_collaborator::Entity::find()
+                    .filter(project_collaborator::Column::ProjectId.eq(project_id))
+                    .stream(&*tx)
+                    .await?;
+                while let Some(row) = rows.next().await {
+                    let row = row?;
+                    let connection = row.connection();
+                    if connection == leader_id {
+                        found_leader = true;
+                    } else if connection == follower_id {
+                        found_follower = true;
+                    }
+                }
+            } else {
+                let mut rows = room_participant::Entity::find()
+                    .filter(room_participant::Column::RoomId.eq(room_id))
+                    .stream(&*tx)
+                    .await?;
+                while let Some(row) = rows.next().await {
+                    let row = row?;
+                    if let Some(connection) = row.answering_connection() {
+                        if connection == leader_id {
+                            found_leader = true;
+                        } else if connection == follower_id {
+                            found_follower = true;
+                        }
+                    }
+                }
+            }
+
+            if !found_leader || !found_follower {
+                Err(anyhow!("not a room participant"))?;
+            }
+
+            Ok(())
+        })
+        .await
+    }
+
+    pub async fn check_can_unfollow(
+        &self,
+        room_id: RoomId,
+        project_id: Option<ProjectId>,
+        leader_id: ConnectionId,
+        follower_id: ConnectionId,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            follower::Entity::find()
+                .filter(
+                    Condition::all()
+                        .add(follower::Column::RoomId.eq(room_id))
+                        .add(follower::Column::ProjectId.eq(project_id))
+                        .add(follower::Column::LeaderConnectionId.eq(leader_id.id as i32))
+                        .add(follower::Column::FollowerConnectionId.eq(follower_id.id as i32))
+                        .add(
+                            follower::Column::LeaderConnectionServerId
+                                .eq(leader_id.owner_id as i32),
+                        )
+                        .add(
+                            follower::Column::FollowerConnectionServerId
+                                .eq(follower_id.owner_id as i32),
+                        ),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("not a follower"))?;
+            Ok(())
+        })
+        .await
+    }
+
     pub async fn follow(
         &self,
-        project_id: ProjectId,
+        room_id: RoomId,
+        project_id: Option<ProjectId>,
         leader_connection: ConnectionId,
         follower_connection: ConnectionId,
     ) -> Result<RoomGuard<proto::Room>> {
-        let room_id = self.room_id_for_project(project_id).await?;
         self.room_transaction(room_id, |tx| async move {
             follower::ActiveModel {
                 room_id: ActiveValue::set(room_id),
@@ -894,15 +976,16 @@ impl Database {
 
     pub async fn unfollow(
         &self,
-        project_id: ProjectId,
+        room_id: RoomId,
+        project_id: Option<ProjectId>,
         leader_connection: ConnectionId,
         follower_connection: ConnectionId,
     ) -> Result<RoomGuard<proto::Room>> {
-        let room_id = self.room_id_for_project(project_id).await?;
         self.room_transaction(room_id, |tx| async move {
             follower::Entity::delete_many()
                 .filter(
                     Condition::all()
+                        .add(follower::Column::RoomId.eq(room_id))
                         .add(follower::Column::ProjectId.eq(project_id))
                         .add(
                             follower::Column::LeaderConnectionServerId

crates/collab/src/db/queries/rooms.rs 🔗

@@ -960,6 +960,65 @@ impl Database {
         Ok(room)
     }
 
+    pub async fn room_id_for_connection(&self, connection_id: ConnectionId) -> Result<RoomId> {
+        #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)]
+        enum QueryRoomId {
+            RoomId,
+        }
+
+        self.transaction(|tx| async move {
+            Ok(room_participant::Entity::find()
+                .select_only()
+                .column(room_participant::Column::RoomId)
+                .filter(
+                    Condition::all()
+                        .add(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
+                        .add(
+                            room_participant::Column::AnsweringConnectionServerId
+                                .eq(ServerId(connection_id.owner_id as i32)),
+                        ),
+                )
+                .into_values::<_, QueryRoomId>()
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no room for connection {:?}", connection_id))?)
+        })
+        .await
+    }
+
+    pub async fn room_connection_ids(
+        &self,
+        room_id: RoomId,
+        connection_id: ConnectionId,
+    ) -> Result<RoomGuard<HashSet<ConnectionId>>> {
+        self.room_transaction(room_id, |tx| async move {
+            let mut participants = room_participant::Entity::find()
+                .filter(room_participant::Column::RoomId.eq(room_id))
+                .stream(&*tx)
+                .await?;
+
+            let mut is_participant = false;
+            let mut connection_ids = HashSet::default();
+            while let Some(participant) = participants.next().await {
+                let participant = participant?;
+                if let Some(answering_connection) = participant.answering_connection() {
+                    if answering_connection == connection_id {
+                        is_participant = true;
+                    } else {
+                        connection_ids.insert(answering_connection);
+                    }
+                }
+            }
+
+            if !is_participant {
+                Err(anyhow!("not a room participant"))?;
+            }
+
+            Ok(connection_ids)
+        })
+        .await
+    }
+
     async fn get_channel_room(
         &self,
         room_id: RoomId,
@@ -1064,7 +1123,7 @@ impl Database {
             followers.push(proto::Follower {
                 leader_id: Some(db_follower.leader_connection().into()),
                 follower_id: Some(db_follower.follower_connection().into()),
-                project_id: db_follower.project_id.to_proto(),
+                project_id: db_follower.project_id.map(|id| id.to_proto()),
             });
         }
 

crates/collab/src/db/tables/follower.rs 🔗

@@ -8,7 +8,7 @@ pub struct Model {
     #[sea_orm(primary_key)]
     pub id: FollowerId,
     pub room_id: RoomId,
-    pub project_id: ProjectId,
+    pub project_id: Option<ProjectId>,
     pub leader_connection_server_id: ServerId,
     pub leader_connection_id: i32,
     pub follower_connection_server_id: ServerId,

crates/collab/src/db/tables/room_participant.rs 🔗

@@ -1,4 +1,5 @@
 use crate::db::{ProjectId, RoomId, RoomParticipantId, ServerId, UserId};
+use rpc::ConnectionId;
 use sea_orm::entity::prelude::*;
 
 #[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)]
@@ -19,6 +20,15 @@ pub struct Model {
     pub calling_connection_server_id: Option<ServerId>,
 }
 
+impl Model {
+    pub fn answering_connection(&self) -> Option<ConnectionId> {
+        Some(ConnectionId {
+            owner_id: self.answering_connection_server_id?.0 as u32,
+            id: self.answering_connection_id? as u32,
+        })
+    }
+}
+
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
 pub enum Relation {
     #[sea_orm(

crates/collab/src/rpc.rs 🔗

@@ -1883,24 +1883,19 @@ async fn follow(
     response: Response<proto::Follow>,
     session: Session,
 ) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
+    let room_id = RoomId::from_proto(request.room_id);
+    let project_id = request.project_id.map(ProjectId::from_proto);
     let leader_id = request
         .leader_id
         .ok_or_else(|| anyhow!("invalid leader id"))?
         .into();
     let follower_id = session.connection_id;
 
-    {
-        let project_connection_ids = session
-            .db()
-            .await
-            .project_connection_ids(project_id, session.connection_id)
-            .await?;
-
-        if !project_connection_ids.contains(&leader_id) {
-            Err(anyhow!("no such peer"))?;
-        }
-    }
+    session
+        .db()
+        .await
+        .check_can_follow(room_id, project_id, leader_id, session.connection_id)
+        .await?;
 
     let mut response_payload = session
         .peer
@@ -1914,7 +1909,7 @@ async fn follow(
     let room = session
         .db()
         .await
-        .follow(project_id, leader_id, follower_id)
+        .follow(room_id, project_id, leader_id, follower_id)
         .await?;
     room_updated(&room, &session.peer);
 
@@ -1922,22 +1917,19 @@ async fn follow(
 }
 
 async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
+    let room_id = RoomId::from_proto(request.room_id);
+    let project_id = request.project_id.map(ProjectId::from_proto);
     let leader_id = request
         .leader_id
         .ok_or_else(|| anyhow!("invalid leader id"))?
         .into();
     let follower_id = session.connection_id;
 
-    if !session
+    session
         .db()
         .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?
-        .contains(&leader_id)
-    {
-        Err(anyhow!("no such peer"))?;
-    }
+        .check_can_unfollow(room_id, project_id, leader_id, session.connection_id)
+        .await?;
 
     session
         .peer
@@ -1946,7 +1938,7 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
     let room = session
         .db()
         .await
-        .unfollow(project_id, leader_id, follower_id)
+        .unfollow(room_id, project_id, leader_id, follower_id)
         .await?;
     room_updated(&room, &session.peer);
 
@@ -1954,13 +1946,19 @@ async fn unfollow(request: proto::Unfollow, session: Session) -> Result<()> {
 }
 
 async fn update_followers(request: proto::UpdateFollowers, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db
-        .lock()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
+    let room_id = RoomId::from_proto(request.room_id);
+    let database = session.db.lock().await;
+
+    let connection_ids = if let Some(project_id) = request.project_id {
+        let project_id = ProjectId::from_proto(project_id);
+        database
+            .project_connection_ids(project_id, session.connection_id)
+            .await?
+    } else {
+        database
+            .room_connection_ids(room_id, session.connection_id)
+            .await?
+    };
 
     let leader_id = request.variant.as_ref().and_then(|variant| match variant {
         proto::update_followers::Variant::CreateView(payload) => payload.leader_id,
@@ -1969,9 +1967,7 @@ async fn update_followers(request: proto::UpdateFollowers, session: Session) ->
     });
     for follower_peer_id in request.follower_ids.iter().copied() {
         let follower_connection_id = follower_peer_id.into();
-        if project_connection_ids.contains(&follower_connection_id)
-            && Some(follower_peer_id) != leader_id
-        {
+        if Some(follower_peer_id) != leader_id && connection_ids.contains(&follower_connection_id) {
             session.peer.forward_send(
                 session.connection_id,
                 follower_connection_id,

crates/rpc/proto/zed.proto 🔗

@@ -274,7 +274,7 @@ message ParticipantProject {
 message Follower {
     PeerId leader_id = 1;
     PeerId follower_id = 2;
-    uint64 project_id = 3;
+    optional uint64 project_id = 3;
 }
 
 message ParticipantLocation {
@@ -1213,8 +1213,9 @@ message UpdateDiagnostics {
 }
 
 message Follow {
-    uint64 project_id = 1;
-    PeerId leader_id = 2;
+    uint64 room_id = 1;
+    optional uint64 project_id = 2;
+    PeerId leader_id = 3;
 }
 
 message FollowResponse {
@@ -1223,18 +1224,20 @@ message FollowResponse {
 }
 
 message UpdateFollowers {
-    uint64 project_id = 1;
-    repeated PeerId follower_ids = 2;
+    uint64 room_id = 1;
+    optional uint64 project_id = 2;
+    repeated PeerId follower_ids = 3;
     oneof variant {
-        UpdateActiveView update_active_view = 3;
-        View create_view = 4;
-        UpdateView update_view = 5;
+        UpdateActiveView update_active_view = 4;
+        View create_view = 5;
+        UpdateView update_view = 6;
     }
 }
 
 message Unfollow {
-    uint64 project_id = 1;
-    PeerId leader_id = 2;
+    uint64 room_id = 1;
+    optional uint64 project_id = 2;
+    PeerId leader_id = 3;
 }
 
 message GetPrivateUserInfo {}

crates/rpc/src/proto.rs 🔗

@@ -364,7 +364,6 @@ entity_messages!(
     CreateProjectEntry,
     DeleteProjectEntry,
     ExpandProjectEntry,
-    Follow,
     FormatBuffers,
     GetCodeActions,
     GetCompletions,
@@ -392,12 +391,10 @@ entity_messages!(
     SearchProject,
     StartLanguageServer,
     SynchronizeBuffers,
-    Unfollow,
     UnshareProject,
     UpdateBuffer,
     UpdateBufferFile,
     UpdateDiagnosticSummary,
-    UpdateFollowers,
     UpdateLanguageServer,
     UpdateProject,
     UpdateProjectCollaborator,

crates/rpc/src/rpc.rs 🔗

@@ -6,4 +6,4 @@ pub use conn::Connection;
 pub use peer::*;
 mod macros;
 
-pub const PROTOCOL_VERSION: u32 = 63;
+pub const PROTOCOL_VERSION: u32 = 64;