Allow deleting chat messages

Max Brunsfeld created

Change summary

crates/channel/src/channel_chat.rs               | 46 ++++++++
crates/collab/src/db/queries/messages.rs         | 40 +++++++
crates/collab/src/rpc.rs                         | 20 +++
crates/collab/src/tests/channel_message_tests.rs | 97 ++++++++++++++++-
crates/collab_ui/src/chat_panel.rs               | 76 +++++++++++--
crates/rpc/proto/zed.proto                       |  8 +
crates/rpc/src/proto.rs                          |  3 
7 files changed, 266 insertions(+), 24 deletions(-)

Detailed changes

crates/channel/src/channel_chat.rs 🔗

@@ -59,6 +59,7 @@ pub enum ChannelChatEvent {
 
 pub fn init(client: &Arc<Client>) {
     client.add_model_message_handler(ChannelChat::handle_message_sent);
+    client.add_model_message_handler(ChannelChat::handle_message_removed);
 }
 
 impl Entity for ChannelChat {
@@ -166,6 +167,21 @@ impl ChannelChat {
         }))
     }
 
+    pub fn remove_message(&mut self, id: u64, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        let response = self.rpc.request(proto::RemoveChannelMessage {
+            channel_id: self.channel.id,
+            message_id: id,
+        });
+        cx.spawn(|this, mut cx| async move {
+            response.await?;
+
+            this.update(&mut cx, |this, cx| {
+                this.message_removed(id, cx);
+                Ok(())
+            })
+        })
+    }
+
     pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
         if !self.loaded_all_messages {
             let rpc = self.rpc.clone();
@@ -306,6 +322,18 @@ impl ChannelChat {
         Ok(())
     }
 
+    async fn handle_message_removed(
+        this: ModelHandle<Self>,
+        message: TypedEnvelope<proto::RemoveChannelMessage>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        this.update(&mut cx, |this, cx| {
+            this.message_removed(message.payload.message_id, cx)
+        });
+        Ok(())
+    }
+
     fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
         if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
             let nonces = messages
@@ -363,6 +391,24 @@ impl ChannelChat {
             cx.notify();
         }
     }
+
+    fn message_removed(&mut self, id: u64, cx: &mut ModelContext<Self>) {
+        let mut cursor = self.messages.cursor::<ChannelMessageId>();
+        let mut messages = cursor.slice(&ChannelMessageId::Saved(id), Bias::Left, &());
+        if let Some(item) = cursor.item() {
+            if item.id == ChannelMessageId::Saved(id) {
+                let ix = messages.summary().count;
+                cursor.next(&());
+                messages.append(cursor.suffix(&()), &());
+                drop(cursor);
+                self.messages = messages;
+                cx.emit(ChannelChatEvent::MessagesUpdated {
+                    old_range: ix..ix + 1,
+                    new_count: 0,
+                });
+            }
+        }
+    }
 }
 
 async fn messages_from_proto(

crates/collab/src/db/queries/messages.rs 🔗

@@ -171,4 +171,44 @@ impl Database {
         })
         .await
     }
+
+    pub async fn remove_channel_message(
+        &self,
+        channel_id: ChannelId,
+        message_id: MessageId,
+        user_id: UserId,
+    ) -> Result<Vec<ConnectionId>> {
+        self.transaction(|tx| async move {
+            let mut rows = channel_chat_participant::Entity::find()
+                .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
+                .stream(&*tx)
+                .await?;
+
+            let mut is_participant = false;
+            let mut participant_connection_ids = Vec::new();
+            while let Some(row) = rows.next().await {
+                let row = row?;
+                if row.user_id == user_id {
+                    is_participant = true;
+                }
+                participant_connection_ids.push(row.connection());
+            }
+            drop(rows);
+
+            if !is_participant {
+                Err(anyhow!("not a chat participant"))?;
+            }
+
+            let result = channel_message::Entity::delete_by_id(message_id)
+                .filter(channel_message::Column::SenderId.eq(user_id))
+                .exec(&*tx)
+                .await?;
+            if result.rows_affected == 0 {
+                Err(anyhow!("no such message"))?;
+            }
+
+            Ok(participant_connection_ids)
+        })
+        .await
+    }
 }

crates/collab/src/rpc.rs 🔗

@@ -265,6 +265,7 @@ impl Server {
             .add_request_handler(join_channel_chat)
             .add_message_handler(leave_channel_chat)
             .add_request_handler(send_channel_message)
+            .add_request_handler(remove_channel_message)
             .add_request_handler(get_channel_messages)
             .add_request_handler(follow)
             .add_message_handler(unfollow)
@@ -2696,6 +2697,25 @@ async fn send_channel_message(
     Ok(())
 }
 
+async fn remove_channel_message(
+    request: proto::RemoveChannelMessage,
+    response: Response<proto::RemoveChannelMessage>,
+    session: Session,
+) -> Result<()> {
+    let channel_id = ChannelId::from_proto(request.channel_id);
+    let message_id = MessageId::from_proto(request.message_id);
+    let connection_ids = session
+        .db()
+        .await
+        .remove_channel_message(channel_id, message_id, session.user_id)
+        .await?;
+    broadcast(Some(session.connection_id), connection_ids, |connection| {
+        session.peer.send(connection, request.clone())
+    });
+    response.send(proto::Ack {})?;
+    Ok(())
+}
+
 async fn join_channel_chat(
     request: proto::JoinChannelChat,
     response: Response<proto::JoinChannelChat>,

crates/collab/src/tests/channel_message_tests.rs 🔗

@@ -1,5 +1,5 @@
 use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
-use channel::ChannelChat;
+use channel::{ChannelChat, ChannelMessageId};
 use gpui::{executor::Deterministic, ModelHandle, TestAppContext};
 use std::sync::Arc;
 
@@ -123,15 +123,92 @@ async fn test_rejoin_channel_chat(
     assert_messages(&channel_chat_b, expected_messages, cx_b);
 }
 
+#[gpui::test]
+async fn test_remove_channel_message(
+    deterministic: Arc<Deterministic>,
+    cx_a: &mut TestAppContext,
+    cx_b: &mut TestAppContext,
+    cx_c: &mut TestAppContext,
+) {
+    deterministic.forbid_parking();
+    let mut server = TestServer::start(&deterministic).await;
+    let client_a = server.create_client(cx_a, "user_a").await;
+    let client_b = server.create_client(cx_b, "user_b").await;
+    let client_c = server.create_client(cx_c, "user_c").await;
+
+    let channel_id = server
+        .make_channel(
+            "the-channel",
+            (&client_a, cx_a),
+            &mut [(&client_b, cx_b), (&client_c, cx_c)],
+        )
+        .await;
+
+    let channel_chat_a = client_a
+        .channel_store()
+        .update(cx_a, |store, cx| store.open_channel_chat(channel_id, cx))
+        .await
+        .unwrap();
+    let channel_chat_b = client_b
+        .channel_store()
+        .update(cx_b, |store, cx| store.open_channel_chat(channel_id, cx))
+        .await
+        .unwrap();
+
+    // Client A sends some messages.
+    channel_chat_a
+        .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap())
+        .await
+        .unwrap();
+    channel_chat_a
+        .update(cx_a, |c, cx| c.send_message("two".into(), cx).unwrap())
+        .await
+        .unwrap();
+    channel_chat_a
+        .update(cx_a, |c, cx| c.send_message("three".into(), cx).unwrap())
+        .await
+        .unwrap();
+
+    // Clients A and B see all of the messages.
+    deterministic.run_until_parked();
+    let expected_messages = &["one", "two", "three"];
+    assert_messages(&channel_chat_a, expected_messages, cx_a);
+    assert_messages(&channel_chat_b, expected_messages, cx_b);
+
+    // Client A deletes one of their messages.
+    channel_chat_a
+        .update(cx_a, |c, cx| {
+            let ChannelMessageId::Saved(id) = c.message(1).id else {
+                panic!("message not saved")
+            };
+            c.remove_message(id, cx)
+        })
+        .await
+        .unwrap();
+
+    // Client B sees that the message is gone.
+    deterministic.run_until_parked();
+    let expected_messages = &["one", "three"];
+    assert_messages(&channel_chat_a, expected_messages, cx_a);
+    assert_messages(&channel_chat_b, expected_messages, cx_b);
+
+    // Client C joins the channel chat, and does not see the deleted message.
+    let channel_chat_c = client_c
+        .channel_store()
+        .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx))
+        .await
+        .unwrap();
+    assert_messages(&channel_chat_c, expected_messages, cx_c);
+}
+
 #[track_caller]
 fn assert_messages(chat: &ModelHandle<ChannelChat>, messages: &[&str], cx: &mut TestAppContext) {
-    chat.update(cx, |chat, _| {
-        assert_eq!(
-            chat.messages()
-                .iter()
-                .map(|m| m.body.as_str())
-                .collect::<Vec<_>>(),
-            messages
-        );
-    })
+    assert_eq!(
+        chat.read_with(cx, |chat, _| chat
+            .messages()
+            .iter()
+            .map(|m| m.body.clone())
+            .collect::<Vec<_>>(),),
+        messages
+    );
 }

crates/collab_ui/src/chat_panel.rs 🔗

@@ -1,6 +1,6 @@
 use crate::ChatPanelSettings;
 use anyhow::Result;
-use channel::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelStore};
+use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore};
 use client::Client;
 use db::kvp::KEY_VALUE_STORE;
 use editor::Editor;
@@ -19,7 +19,7 @@ use project::Fs;
 use serde::{Deserialize, Serialize};
 use settings::SettingsStore;
 use std::sync::Arc;
-use theme::Theme;
+use theme::{IconButton, Theme};
 use time::{OffsetDateTime, UtcOffset};
 use util::{ResultExt, TryFutureExt};
 use workspace::{
@@ -105,8 +105,7 @@ impl ChatPanel {
 
         let mut message_list =
             ListState::<Self>::new(0, Orientation::Bottom, 1000., move |this, ix, cx| {
-                let message = this.active_chat.as_ref().unwrap().0.read(cx).message(ix);
-                this.render_message(message, cx)
+                this.render_message(ix, cx)
             });
         message_list.set_scroll_handler(|visible_range, this, cx| {
             if visible_range.start < MESSAGE_LOADING_THRESHOLD {
@@ -285,38 +284,70 @@ impl ChatPanel {
         messages.flex(1., true).into_any()
     }
 
-    fn render_message(&self, message: &ChannelMessage, cx: &AppContext) -> AnyElement<Self> {
+    fn render_message(&self, ix: usize, cx: &mut ViewContext<Self>) -> AnyElement<Self> {
+        let message = self.active_chat.as_ref().unwrap().0.read(cx).message(ix);
+
         let now = OffsetDateTime::now_utc();
         let theme = theme::current(cx);
-        let theme = if message.is_pending() {
+        let style = if message.is_pending() {
             &theme.chat_panel.pending_message
         } else {
             &theme.chat_panel.message
         };
 
+        let belongs_to_user = Some(message.sender.id) == self.client.user_id();
+        let message_id_to_remove =
+            if let (ChannelMessageId::Saved(id), true) = (message.id, belongs_to_user) {
+                Some(id)
+            } else {
+                None
+            };
+
+        enum DeleteMessage {}
+
+        let body = message.body.clone();
         Flex::column()
             .with_child(
                 Flex::row()
                     .with_child(
                         Label::new(
                             message.sender.github_login.clone(),
-                            theme.sender.text.clone(),
+                            style.sender.text.clone(),
                         )
                         .contained()
-                        .with_style(theme.sender.container),
+                        .with_style(style.sender.container),
                     )
                     .with_child(
                         Label::new(
                             format_timestamp(message.timestamp, now, self.local_timezone),
-                            theme.timestamp.text.clone(),
+                            style.timestamp.text.clone(),
                         )
                         .contained()
-                        .with_style(theme.timestamp.container),
-                    ),
+                        .with_style(style.timestamp.container),
+                    )
+                    .with_children(message_id_to_remove.map(|id| {
+                        MouseEventHandler::new::<DeleteMessage, _>(
+                            id as usize,
+                            cx,
+                            |mouse_state, _| {
+                                let button_style =
+                                    theme.collab_panel.contact_button.style_for(mouse_state);
+                                render_icon_button(button_style, "icons/x.svg")
+                                    .aligned()
+                                    .into_any()
+                            },
+                        )
+                        .with_padding(Padding::uniform(2.))
+                        .with_cursor_style(CursorStyle::PointingHand)
+                        .on_click(MouseButton::Left, move |_, this, cx| {
+                            this.remove_message(id, cx);
+                        })
+                        .flex_float()
+                    })),
             )
-            .with_child(Text::new(message.body.clone(), theme.body.clone()))
+            .with_child(Text::new(body, style.body.clone()))
             .contained()
-            .with_style(theme.container)
+            .with_style(style.container)
             .into_any()
     }
 
@@ -413,6 +444,12 @@ impl ChatPanel {
         }
     }
 
+    fn remove_message(&mut self, id: u64, cx: &mut ViewContext<Self>) {
+        if let Some((chat, _)) = self.active_chat.as_ref() {
+            chat.update(cx, |chat, cx| chat.remove_message(id, cx).detach())
+        }
+    }
+
     fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) {
         if let Some((chat, _)) = self.active_chat.as_ref() {
             chat.update(cx, |channel, cx| {
@@ -551,3 +588,16 @@ fn format_timestamp(
         format!("{:02}/{}/{}", date.month() as u32, date.day(), date.year())
     }
 }
+
+fn render_icon_button(style: &IconButton, svg_path: &'static str) -> impl Element<ChatPanel> {
+    Svg::new(svg_path)
+        .with_color(style.color)
+        .constrained()
+        .with_width(style.icon_width)
+        .aligned()
+        .constrained()
+        .with_width(style.button_width)
+        .with_height(style.button_width)
+        .contained()
+        .with_style(style.container)
+}

crates/rpc/proto/zed.proto 🔗

@@ -164,7 +164,8 @@ message Envelope {
         SendChannelMessageResponse send_channel_message_response = 146;
         ChannelMessageSent channel_message_sent = 147;
         GetChannelMessages get_channel_messages = 148;
-        GetChannelMessagesResponse get_channel_messages_response = 149; // Current max
+        GetChannelMessagesResponse get_channel_messages_response = 149;
+        RemoveChannelMessage remove_channel_message = 150; // Current max
     }
 }
 
@@ -1049,6 +1050,11 @@ message SendChannelMessage {
     Nonce nonce = 3;
 }
 
+message RemoveChannelMessage {
+    uint64 channel_id = 1;
+    uint64 message_id = 2;
+}
+
 message SendChannelMessageResponse {
     ChannelMessage message = 1;
 }

crates/rpc/src/proto.rs 🔗

@@ -217,6 +217,7 @@ messages!(
     (RejoinRoomResponse, Foreground),
     (RemoveContact, Foreground),
     (RemoveChannelMember, Foreground),
+    (RemoveChannelMessage, Foreground),
     (ReloadBuffers, Foreground),
     (ReloadBuffersResponse, Foreground),
     (RemoveProjectCollaborator, Foreground),
@@ -327,6 +328,7 @@ request_messages!(
     (GetChannelMembers, GetChannelMembersResponse),
     (JoinChannel, JoinRoomResponse),
     (RemoveChannel, Ack),
+    (RemoveChannelMessage, Ack),
     (RenameProjectEntry, ProjectEntryResponse),
     (RenameChannel, ChannelResponse),
     (SaveBuffer, BufferSaved),
@@ -402,6 +404,7 @@ entity_messages!(
     ChannelMessageSent,
     UpdateChannelBuffer,
     RemoveChannelBufferCollaborator,
+    RemoveChannelMessage,
     AddChannelBufferCollaborator,
     UpdateChannelBufferCollaborator
 );