Check user is host for host-broadcasted project messages

Max Brunsfeld created

Change summary

crates/collab/src/db/queries/projects.rs | 28 +++++++
crates/collab/src/rpc.rs                 | 92 ++++---------------------
crates/rpc/src/macros.rs                 |  4 
crates/rpc/src/proto.rs                  |  5 
4 files changed, 51 insertions(+), 78 deletions(-)

Detailed changes

crates/collab/src/db/queries/projects.rs 🔗

@@ -777,6 +777,34 @@ impl Database {
         .await
     }
 
+    pub async fn check_user_is_project_host(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<()> {
+        let room_id = self.room_id_for_project(project_id).await?;
+        self.room_transaction(room_id, |tx| async move {
+            project_collaborator::Entity::find()
+                .filter(
+                    Condition::all()
+                        .add(project_collaborator::Column::ProjectId.eq(project_id))
+                        .add(project_collaborator::Column::IsHost.eq(true))
+                        .add(project_collaborator::Column::ConnectionId.eq(connection_id.id))
+                        .add(
+                            project_collaborator::Column::ConnectionServerId
+                                .eq(connection_id.owner_id),
+                        ),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("failed to read project host"))?;
+
+            Ok(())
+        })
+        .await
+        .map(|guard| guard.into_inner())
+    }
+
     pub async fn host_for_mutating_project_request(
         &self,
         project_id: ProjectId,

crates/collab/src/rpc.rs 🔗

@@ -42,7 +42,7 @@ use prometheus::{register_int_gauge, IntGauge};
 use rpc::{
     proto::{
         self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
-        RequestMessage, UpdateChannelBufferCollaborators,
+        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
     },
     Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
 };
@@ -216,7 +216,6 @@ impl Server {
             .add_message_handler(update_language_server)
             .add_message_handler(update_diagnostic_summary)
             .add_message_handler(update_worktree_settings)
-            .add_message_handler(refresh_inlay_hints)
             .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
             .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
             .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
@@ -251,9 +250,11 @@ impl Server {
             .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
             .add_message_handler(create_buffer_for_peer)
             .add_request_handler(update_buffer)
-            .add_message_handler(update_buffer_file)
-            .add_message_handler(buffer_reloaded)
-            .add_message_handler(buffer_saved)
+            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
             .add_request_handler(get_users)
             .add_request_handler(fuzzy_search_users)
             .add_request_handler(request_contact)
@@ -285,7 +286,6 @@ impl Server {
             .add_request_handler(follow)
             .add_message_handler(unfollow)
             .add_message_handler(update_followers)
-            .add_message_handler(update_diff_base)
             .add_request_handler(get_private_user_info)
             .add_message_handler(acknowledge_channel_message)
             .add_message_handler(acknowledge_buffer_version);
@@ -1697,10 +1697,6 @@ async fn update_worktree_settings(
     Ok(())
 }
 
-async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
-    broadcast_project_message(request.project_id, request, session).await
-}
-
 async fn start_language_server(
     request: proto::StartLanguageServer,
     session: Session,
@@ -1804,6 +1800,14 @@ async fn create_buffer_for_peer(
     request: proto::CreateBufferForPeer,
     session: Session,
 ) -> Result<()> {
+    session
+        .db()
+        .await
+        .check_user_is_project_host(
+            ProjectId::from_proto(request.project_id),
+            session.connection_id,
+        )
+        .await?;
     let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
     session
         .peer
@@ -1856,60 +1860,17 @@ async fn update_buffer(
     Ok(())
 }
 
-async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
-
-    broadcast(
-        Some(session.connection_id),
-        project_connection_ids.iter().copied(),
-        |connection_id| {
-            session
-                .peer
-                .forward_send(session.connection_id, connection_id, request.clone())
-        },
-    );
-    Ok(())
-}
-
-async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
-    broadcast(
-        Some(session.connection_id),
-        project_connection_ids.iter().copied(),
-        |connection_id| {
-            session
-                .peer
-                .forward_send(session.connection_id, connection_id, request.clone())
-        },
-    );
-    Ok(())
-}
-
-async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
-    broadcast_project_message(request.project_id, request, session).await
-}
-
-async fn broadcast_project_message<T: EnvelopedMessage>(
-    project_id: u64,
+async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
     request: T,
     session: Session,
 ) -> Result<()> {
-    let project_id = ProjectId::from_proto(project_id);
+    let project_id = ProjectId::from_proto(request.remote_entity_id());
     let project_connection_ids = session
         .db()
         .await
         .project_connection_ids(project_id, session.connection_id)
         .await?;
+
     broadcast(
         Some(session.connection_id),
         project_connection_ids.iter().copied(),
@@ -3138,25 +3099,6 @@ async fn mark_notification_as_read(
     Ok(())
 }
 
-async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
-    broadcast(
-        Some(session.connection_id),
-        project_connection_ids.iter().copied(),
-        |connection_id| {
-            session
-                .peer
-                .forward_send(session.connection_id, connection_id, request.clone())
-        },
-    );
-    Ok(())
-}
-
 async fn get_private_user_info(
     _request: proto::GetPrivateUserInfo,
     response: Response<proto::GetPrivateUserInfo>,

crates/rpc/src/macros.rs 🔗

@@ -60,8 +60,10 @@ macro_rules! request_messages {
 
 #[macro_export]
 macro_rules! entity_messages {
-    ($id_field:ident, $($name:ident),* $(,)?) => {
+    ({$id_field:ident, $entity_type:ty}, $($name:ident),* $(,)?) => {
         $(impl EntityMessage for $name {
+            type Entity = $entity_type;
+
             fn remote_entity_id(&self) -> u64 {
                 self.$id_field
             }

crates/rpc/src/proto.rs 🔗

@@ -31,6 +31,7 @@ pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 's
 }
 
 pub trait EntityMessage: EnvelopedMessage {
+    type Entity;
     fn remote_entity_id(&self) -> u64;
 }
 
@@ -369,7 +370,7 @@ request_messages!(
 );
 
 entity_messages!(
-    project_id,
+    {project_id, ShareProject},
     AddProjectCollaborator,
     ApplyCodeAction,
     ApplyCompletionAdditionalEdits,
@@ -422,7 +423,7 @@ entity_messages!(
 );
 
 entity_messages!(
-    channel_id,
+    {channel_id, Channel},
     ChannelMessageSent,
     RemoveChannelMessage,
     UpdateChannelBuffer,