guest acls (#3958)

Conrad Irwin created

- First pass of real access control
- Check user is host for host-broadcasted project messages
- Only allow read-write users to update buffers

[[PR Description]]

Release Notes:

- (Added|Fixed|Improved) ...
([#<public_issue_number_if_exists>](https://github.com/zed-industries/community/issues/<public_issue_number_if_exists>)).

Change summary

crates/collab/src/db/ids.rs                    |  16 +
crates/collab/src/db/queries/projects.rs       | 118 +++++++++++
crates/collab/src/rpc.rs                       | 196 +++++++------------
crates/collab/src/tests/channel_guest_tests.rs |  10 
crates/rpc/src/macros.rs                       |   4 
crates/rpc/src/proto.rs                        |   5 
6 files changed, 226 insertions(+), 123 deletions(-)

Detailed changes

crates/collab/src/db/ids.rs 🔗

@@ -140,6 +140,22 @@ impl ChannelRole {
             Guest | Banned => false,
         }
     }
+
+    pub fn can_edit_projects(&self) -> bool {
+        use ChannelRole::*;
+        match self {
+            Admin | Member => true,
+            Guest | Banned => false,
+        }
+    }
+
+    pub fn can_read_projects(&self) -> bool {
+        use ChannelRole::*;
+        match self {
+            Admin | Member | Guest => true,
+            Banned => false,
+        }
+    }
 }
 
 impl From<proto::ChannelRole> for ChannelRole {

crates/collab/src/db/queries/projects.rs 🔗

@@ -777,13 +777,129 @@ impl Database {
         .await
     }
 
-    pub async fn project_collaborators(
+    pub async fn check_user_is_project_host(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<()> {
+        let room_id = self.room_id_for_project(project_id).await?;
+        self.room_transaction(room_id, |tx| async move {
+            project_collaborator::Entity::find()
+                .filter(
+                    Condition::all()
+                        .add(project_collaborator::Column::ProjectId.eq(project_id))
+                        .add(project_collaborator::Column::IsHost.eq(true))
+                        .add(project_collaborator::Column::ConnectionId.eq(connection_id.id))
+                        .add(
+                            project_collaborator::Column::ConnectionServerId
+                                .eq(connection_id.owner_id),
+                        ),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("failed to read project host"))?;
+
+            Ok(())
+        })
+        .await
+        .map(|guard| guard.into_inner())
+    }
+
+    pub async fn host_for_read_only_project_request(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<ConnectionId> {
+        let room_id = self.room_id_for_project(project_id).await?;
+        self.room_transaction(room_id, |tx| async move {
+            let current_participant = room_participant::Entity::find()
+                .filter(room_participant::Column::RoomId.eq(room_id))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such room"))?;
+
+            if !current_participant
+                .role
+                .map_or(false, |role| role.can_read_projects())
+            {
+                Err(anyhow!("not authorized to read projects"))?;
+            }
+
+            let host = project_collaborator::Entity::find()
+                .filter(
+                    project_collaborator::Column::ProjectId
+                        .eq(project_id)
+                        .and(project_collaborator::Column::IsHost.eq(true)),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("failed to read project host"))?;
+
+            Ok(host.connection())
+        })
+        .await
+        .map(|guard| guard.into_inner())
+    }
+
+    pub async fn host_for_mutating_project_request(
+        &self,
+        project_id: ProjectId,
+        connection_id: ConnectionId,
+    ) -> Result<ConnectionId> {
+        let room_id = self.room_id_for_project(project_id).await?;
+        self.room_transaction(room_id, |tx| async move {
+            let current_participant = room_participant::Entity::find()
+                .filter(room_participant::Column::RoomId.eq(room_id))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such room"))?;
+
+            if !current_participant
+                .role
+                .map_or(false, |role| role.can_edit_projects())
+            {
+                Err(anyhow!("not authorized to edit projects"))?;
+            }
+
+            let host = project_collaborator::Entity::find()
+                .filter(
+                    project_collaborator::Column::ProjectId
+                        .eq(project_id)
+                        .and(project_collaborator::Column::IsHost.eq(true)),
+                )
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("failed to read project host"))?;
+
+            Ok(host.connection())
+        })
+        .await
+        .map(|guard| guard.into_inner())
+    }
+
+    pub async fn project_collaborators_for_buffer_update(
         &self,
         project_id: ProjectId,
         connection_id: ConnectionId,
     ) -> Result<RoomGuard<Vec<ProjectCollaborator>>> {
         let room_id = self.room_id_for_project(project_id).await?;
         self.room_transaction(room_id, |tx| async move {
+            let current_participant = room_participant::Entity::find()
+                .filter(room_participant::Column::RoomId.eq(room_id))
+                .filter(room_participant::Column::AnsweringConnectionId.eq(connection_id.id))
+                .one(&*tx)
+                .await?
+                .ok_or_else(|| anyhow!("no such room"))?;
+
+            if !current_participant
+                .role
+                .map_or(false, |role| role.can_edit_projects())
+            {
+                Err(anyhow!("not authorized to edit projects"))?;
+            }
+
             let collaborators = project_collaborator::Entity::find()
                 .filter(project_collaborator::Column::ProjectId.eq(project_id))
                 .all(&*tx)

crates/collab/src/rpc.rs 🔗

@@ -42,7 +42,7 @@ use prometheus::{register_int_gauge, IntGauge};
 use rpc::{
     proto::{
         self, Ack, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, LiveKitConnectionInfo,
-        RequestMessage, UpdateChannelBufferCollaborators,
+        RequestMessage, ShareProject, UpdateChannelBufferCollaborators,
     },
     Connection, ConnectionId, Peer, Receipt, TypedEnvelope,
 };
@@ -216,40 +216,45 @@ impl Server {
             .add_message_handler(update_language_server)
             .add_message_handler(update_diagnostic_summary)
             .add_message_handler(update_worktree_settings)
-            .add_message_handler(refresh_inlay_hints)
-            .add_request_handler(forward_project_request::<proto::GetHover>)
-            .add_request_handler(forward_project_request::<proto::GetDefinition>)
-            .add_request_handler(forward_project_request::<proto::GetTypeDefinition>)
-            .add_request_handler(forward_project_request::<proto::GetReferences>)
-            .add_request_handler(forward_project_request::<proto::SearchProject>)
-            .add_request_handler(forward_project_request::<proto::GetDocumentHighlights>)
-            .add_request_handler(forward_project_request::<proto::GetProjectSymbols>)
-            .add_request_handler(forward_project_request::<proto::OpenBufferForSymbol>)
-            .add_request_handler(forward_project_request::<proto::OpenBufferById>)
-            .add_request_handler(forward_project_request::<proto::OpenBufferByPath>)
-            .add_request_handler(forward_project_request::<proto::GetCompletions>)
-            .add_request_handler(forward_project_request::<proto::ApplyCompletionAdditionalEdits>)
-            .add_request_handler(forward_project_request::<proto::ResolveCompletionDocumentation>)
-            .add_request_handler(forward_project_request::<proto::GetCodeActions>)
-            .add_request_handler(forward_project_request::<proto::ApplyCodeAction>)
-            .add_request_handler(forward_project_request::<proto::PrepareRename>)
-            .add_request_handler(forward_project_request::<proto::PerformRename>)
-            .add_request_handler(forward_project_request::<proto::ReloadBuffers>)
-            .add_request_handler(forward_project_request::<proto::SynchronizeBuffers>)
-            .add_request_handler(forward_project_request::<proto::FormatBuffers>)
-            .add_request_handler(forward_project_request::<proto::CreateProjectEntry>)
-            .add_request_handler(forward_project_request::<proto::RenameProjectEntry>)
-            .add_request_handler(forward_project_request::<proto::CopyProjectEntry>)
-            .add_request_handler(forward_project_request::<proto::DeleteProjectEntry>)
-            .add_request_handler(forward_project_request::<proto::ExpandProjectEntry>)
-            .add_request_handler(forward_project_request::<proto::OnTypeFormatting>)
-            .add_request_handler(forward_project_request::<proto::InlayHints>)
+            .add_request_handler(forward_read_only_project_request::<proto::GetHover>)
+            .add_request_handler(forward_read_only_project_request::<proto::GetDefinition>)
+            .add_request_handler(forward_read_only_project_request::<proto::GetTypeDefinition>)
+            .add_request_handler(forward_read_only_project_request::<proto::GetReferences>)
+            .add_request_handler(forward_read_only_project_request::<proto::SearchProject>)
+            .add_request_handler(forward_read_only_project_request::<proto::GetDocumentHighlights>)
+            .add_request_handler(forward_read_only_project_request::<proto::GetProjectSymbols>)
+            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferForSymbol>)
+            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferById>)
+            .add_request_handler(forward_read_only_project_request::<proto::SynchronizeBuffers>)
+            .add_request_handler(forward_read_only_project_request::<proto::InlayHints>)
+            .add_request_handler(forward_read_only_project_request::<proto::OpenBufferByPath>)
+            .add_request_handler(forward_mutating_project_request::<proto::GetCompletions>)
+            .add_request_handler(
+                forward_mutating_project_request::<proto::ApplyCompletionAdditionalEdits>,
+            )
+            .add_request_handler(
+                forward_mutating_project_request::<proto::ResolveCompletionDocumentation>,
+            )
+            .add_request_handler(forward_mutating_project_request::<proto::GetCodeActions>)
+            .add_request_handler(forward_mutating_project_request::<proto::ApplyCodeAction>)
+            .add_request_handler(forward_mutating_project_request::<proto::PrepareRename>)
+            .add_request_handler(forward_mutating_project_request::<proto::PerformRename>)
+            .add_request_handler(forward_mutating_project_request::<proto::ReloadBuffers>)
+            .add_request_handler(forward_mutating_project_request::<proto::FormatBuffers>)
+            .add_request_handler(forward_mutating_project_request::<proto::CreateProjectEntry>)
+            .add_request_handler(forward_mutating_project_request::<proto::RenameProjectEntry>)
+            .add_request_handler(forward_mutating_project_request::<proto::CopyProjectEntry>)
+            .add_request_handler(forward_mutating_project_request::<proto::DeleteProjectEntry>)
+            .add_request_handler(forward_mutating_project_request::<proto::ExpandProjectEntry>)
+            .add_request_handler(forward_mutating_project_request::<proto::OnTypeFormatting>)
+            .add_request_handler(forward_mutating_project_request::<proto::SaveBuffer>)
             .add_message_handler(create_buffer_for_peer)
             .add_request_handler(update_buffer)
-            .add_message_handler(update_buffer_file)
-            .add_message_handler(buffer_reloaded)
-            .add_message_handler(buffer_saved)
-            .add_request_handler(forward_project_request::<proto::SaveBuffer>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::RefreshInlayHints>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateBufferFile>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::BufferReloaded>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::BufferSaved>)
+            .add_message_handler(broadcast_project_message_from_host::<proto::UpdateDiffBase>)
             .add_request_handler(get_users)
             .add_request_handler(fuzzy_search_users)
             .add_request_handler(request_contact)
@@ -281,7 +286,6 @@ impl Server {
             .add_request_handler(follow)
             .add_message_handler(unfollow)
             .add_message_handler(update_followers)
-            .add_message_handler(update_diff_base)
             .add_request_handler(get_private_user_info)
             .add_message_handler(acknowledge_channel_message)
             .add_message_handler(acknowledge_buffer_version);
@@ -1694,10 +1698,6 @@ async fn update_worktree_settings(
     Ok(())
 }
 
-async fn refresh_inlay_hints(request: proto::RefreshInlayHints, session: Session) -> Result<()> {
-    broadcast_project_message(request.project_id, request, session).await
-}
-
 async fn start_language_server(
     request: proto::StartLanguageServer,
     session: Session,
@@ -1742,7 +1742,7 @@ async fn update_language_server(
     Ok(())
 }
 
-async fn forward_project_request<T>(
+async fn forward_read_only_project_request<T>(
     request: T,
     response: Response<T>,
     session: Session,
@@ -1751,24 +1751,37 @@ where
     T: EntityMessage + RequestMessage,
 {
     let project_id = ProjectId::from_proto(request.remote_entity_id());
-    let host_connection_id = {
-        let collaborators = session
-            .db()
-            .await
-            .project_collaborators(project_id, session.connection_id)
-            .await?;
-        collaborators
-            .iter()
-            .find(|collaborator| collaborator.is_host)
-            .ok_or_else(|| anyhow!("host not found"))?
-            .connection_id
-    };
-
+    let host_connection_id = session
+        .db()
+        .await
+        .host_for_read_only_project_request(project_id, session.connection_id)
+        .await?;
     let payload = session
         .peer
         .forward_request(session.connection_id, host_connection_id, request)
         .await?;
+    response.send(payload)?;
+    Ok(())
+}
 
+async fn forward_mutating_project_request<T>(
+    request: T,
+    response: Response<T>,
+    session: Session,
+) -> Result<()>
+where
+    T: EntityMessage + RequestMessage,
+{
+    let project_id = ProjectId::from_proto(request.remote_entity_id());
+    let host_connection_id = session
+        .db()
+        .await
+        .host_for_mutating_project_request(project_id, session.connection_id)
+        .await?;
+    let payload = session
+        .peer
+        .forward_request(session.connection_id, host_connection_id, request)
+        .await?;
     response.send(payload)?;
     Ok(())
 }
@@ -1777,6 +1790,14 @@ async fn create_buffer_for_peer(
     request: proto::CreateBufferForPeer,
     session: Session,
 ) -> Result<()> {
+    session
+        .db()
+        .await
+        .check_user_is_project_host(
+            ProjectId::from_proto(request.project_id),
+            session.connection_id,
+        )
+        .await?;
     let peer_id = request.peer_id.ok_or_else(|| anyhow!("invalid peer id"))?;
     session
         .peer
@@ -1792,11 +1813,12 @@ async fn update_buffer(
     let project_id = ProjectId::from_proto(request.project_id);
     let mut guest_connection_ids;
     let mut host_connection_id = None;
+
     {
         let collaborators = session
             .db()
             .await
-            .project_collaborators(project_id, session.connection_id)
+            .project_collaborators_for_buffer_update(project_id, session.connection_id)
             .await?;
         guest_connection_ids = Vec::with_capacity(collaborators.len() - 1);
         for collaborator in collaborators.iter() {
@@ -1829,60 +1851,17 @@ async fn update_buffer(
     Ok(())
 }
 
-async fn update_buffer_file(request: proto::UpdateBufferFile, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
-
-    broadcast(
-        Some(session.connection_id),
-        project_connection_ids.iter().copied(),
-        |connection_id| {
-            session
-                .peer
-                .forward_send(session.connection_id, connection_id, request.clone())
-        },
-    );
-    Ok(())
-}
-
-async fn buffer_reloaded(request: proto::BufferReloaded, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
-    broadcast(
-        Some(session.connection_id),
-        project_connection_ids.iter().copied(),
-        |connection_id| {
-            session
-                .peer
-                .forward_send(session.connection_id, connection_id, request.clone())
-        },
-    );
-    Ok(())
-}
-
-async fn buffer_saved(request: proto::BufferSaved, session: Session) -> Result<()> {
-    broadcast_project_message(request.project_id, request, session).await
-}
-
-async fn broadcast_project_message<T: EnvelopedMessage>(
-    project_id: u64,
+async fn broadcast_project_message_from_host<T: EntityMessage<Entity = ShareProject>>(
     request: T,
     session: Session,
 ) -> Result<()> {
-    let project_id = ProjectId::from_proto(project_id);
+    let project_id = ProjectId::from_proto(request.remote_entity_id());
     let project_connection_ids = session
         .db()
         .await
         .project_connection_ids(project_id, session.connection_id)
         .await?;
+
     broadcast(
         Some(session.connection_id),
         project_connection_ids.iter().copied(),
@@ -3111,25 +3090,6 @@ async fn mark_notification_as_read(
     Ok(())
 }
 
-async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> {
-    let project_id = ProjectId::from_proto(request.project_id);
-    let project_connection_ids = session
-        .db()
-        .await
-        .project_connection_ids(project_id, session.connection_id)
-        .await?;
-    broadcast(
-        Some(session.connection_id),
-        project_connection_ids.iter().copied(),
-        |connection_id| {
-            session
-                .peer
-                .forward_send(session.connection_id, connection_id, request.clone())
-        },
-    );
-    Ok(())
-}
-
 async fn get_private_user_info(
     _request: proto::GetPrivateUserInfo,
     response: Response<proto::GetPrivateUserInfo>,

crates/collab/src/tests/channel_guest_tests.rs 🔗

@@ -82,5 +82,13 @@ async fn test_channel_guests(
         project_b.read_with(cx_b, |project, _| project.remote_id()),
         Some(project_id),
     );
-    assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()))
+    assert!(project_b.read_with(cx_b, |project, _| project.is_read_only()));
+
+    assert!(project_b
+        .update(cx_b, |project, cx| {
+            let worktree_id = project.worktrees().next().unwrap().read(cx).id();
+            project.create_entry((worktree_id, "b.txt"), false, cx)
+        })
+        .await
+        .is_err())
 }

crates/rpc/src/macros.rs 🔗

@@ -60,8 +60,10 @@ macro_rules! request_messages {
 
 #[macro_export]
 macro_rules! entity_messages {
-    ($id_field:ident, $($name:ident),* $(,)?) => {
+    ({$id_field:ident, $entity_type:ty}, $($name:ident),* $(,)?) => {
         $(impl EntityMessage for $name {
+            type Entity = $entity_type;
+
             fn remote_entity_id(&self) -> u64 {
                 self.$id_field
             }

crates/rpc/src/proto.rs 🔗

@@ -31,6 +31,7 @@ pub trait EnvelopedMessage: Clone + Debug + Serialize + Sized + Send + Sync + 's
 }
 
 pub trait EntityMessage: EnvelopedMessage {
+    type Entity;
     fn remote_entity_id(&self) -> u64;
 }
 
@@ -369,7 +370,7 @@ request_messages!(
 );
 
 entity_messages!(
-    project_id,
+    {project_id, ShareProject},
     AddProjectCollaborator,
     ApplyCodeAction,
     ApplyCompletionAdditionalEdits,
@@ -422,7 +423,7 @@ entity_messages!(
 );
 
 entity_messages!(
-    channel_id,
+    {channel_id, Channel},
     ChannelMessageSent,
     RemoveChannelMessage,
     UpdateChannelBuffer,