Trim whitespace from chat messages and limit their length

Max Brunsfeld created

Add a way for the server to respond to any request with an error

Change summary

server/src/rpc.rs     | 107 +++++++++++++++++++++++++++++++++++++++-----
zed/src/channel.rs    |  70 ++++++++++++++---------------
zed/src/chat_panel.rs |   7 ++
zrpc/proto/zed.proto  |  61 +++++++++++++-----------
zrpc/src/peer.rs      |  27 ++++++++++
zrpc/src/proto.rs     |   1 
6 files changed, 191 insertions(+), 82 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -77,8 +77,14 @@ struct Channel {
     connection_ids: HashSet<ConnectionId>,
 }
 
+#[cfg(debug_assertions)]
+const MESSAGE_COUNT_PER_PAGE: usize = 10;
+
+#[cfg(not(debug_assertions))]
 const MESSAGE_COUNT_PER_PAGE: usize = 50;
 
+const MAX_MESSAGE_LEN: usize = 1024;
+
 impl Server {
     pub fn new(
         app_state: Arc<AppState>,
@@ -661,20 +667,33 @@ impl Server {
             }
         }
 
+        let receipt = request.receipt();
+        let body = request.payload.body.trim().to_string();
+        if body.len() > MAX_MESSAGE_LEN {
+            self.peer
+                .respond_with_error(
+                    receipt,
+                    proto::Error {
+                        message: "message is too long".to_string(),
+                    },
+                )
+                .await?;
+            return Ok(());
+        }
+
         let timestamp = OffsetDateTime::now_utc();
         let message_id = self
             .app_state
             .db
-            .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
+            .create_channel_message(channel_id, user_id, &body, timestamp)
             .await?
             .to_proto();
-        let receipt = request.receipt();
         let message = proto::ChannelMessageSent {
             channel_id: channel_id.to_proto(),
             message: Some(proto::ChannelMessage {
                 sender_id: user_id.to_proto(),
                 id: message_id,
-                body: request.payload.body,
+                body,
                 timestamp: timestamp.unix_timestamp() as u64,
             }),
         };
@@ -1530,18 +1549,25 @@ mod tests {
             })
             .await;
 
-        channel_a.update(&mut cx_a, |channel, cx| {
-            channel.send_message("oh, hi B.".to_string(), cx).unwrap();
-            channel.send_message("sup".to_string(), cx).unwrap();
-            assert_eq!(
+        channel_a
+            .update(&mut cx_a, |channel, cx| {
                 channel
-                    .pending_messages()
-                    .iter()
-                    .map(|m| &m.body)
-                    .collect::<Vec<_>>(),
-                &["oh, hi B.", "sup"]
-            )
-        });
+                    .send_message("oh, hi B.".to_string(), cx)
+                    .unwrap()
+                    .detach();
+                let task = channel.send_message("sup".to_string(), cx).unwrap();
+                assert_eq!(
+                    channel
+                        .pending_messages()
+                        .iter()
+                        .map(|m| &m.body)
+                        .collect::<Vec<_>>(),
+                    &["oh, hi B.", "sup"]
+                );
+                task
+            })
+            .await
+            .unwrap();
 
         channel_a
             .condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
@@ -1582,6 +1608,59 @@ mod tests {
         }
     }
 
+    #[gpui::test]
+    async fn test_chat_message_validation(mut cx_a: TestAppContext) {
+        cx_a.foreground().forbid_parking();
+
+        let mut server = TestServer::start().await;
+        let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+
+        let db = &server.app_state.db;
+        let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+        let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
+        db.add_org_member(org_id, user_id_a, false).await.unwrap();
+        db.add_channel_member(channel_id, user_id_a, false)
+            .await
+            .unwrap();
+
+        let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+        let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
+        channels_a
+            .condition(&mut cx_a, |list, _| list.available_channels().is_some())
+            .await;
+        let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+            this.get_channel(channel_id.to_proto(), cx).unwrap()
+        });
+
+        // Leading and trailing whitespace are trimmed.
+        channel_a
+            .update(&mut cx_a, |channel, cx| {
+                channel
+                    .send_message("\n surrounded by whitespace  \n".to_string(), cx)
+                    .unwrap()
+            })
+            .await
+            .unwrap();
+        assert_eq!(
+            db.get_channel_messages(channel_id, 10, None)
+                .await
+                .unwrap()
+                .iter()
+                .map(|m| &m.body)
+                .collect::<Vec<_>>(),
+            &["surrounded by whitespace"]
+        );
+
+        // Messages aren't allowed to be too long.
+        channel_a
+            .update(&mut cx_a, |channel, cx| {
+                let long_body = "this is long.\n".repeat(1024);
+                channel.send_message(long_body, cx).unwrap()
+            })
+            .await
+            .unwrap_err();
+    }
+
     struct TestServer {
         peer: Arc<Peer>,
         app_state: Arc<AppState>,

zed/src/channel.rs 🔗

@@ -224,7 +224,11 @@ impl Channel {
         &self.details.name
     }
 
-    pub fn send_message(&mut self, body: String, cx: &mut ModelContext<Self>) -> Result<()> {
+    pub fn send_message(
+        &mut self,
+        body: String,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<Task<Result<()>>> {
         let channel_id = self.details.id;
         let current_user_id = self.current_user_id()?;
         let local_id = self.next_local_message_id;
@@ -235,41 +239,35 @@ impl Channel {
         });
         let user_store = self.user_store.clone();
         let rpc = self.rpc.clone();
-        cx.spawn(|this, mut cx| {
-            async move {
-                let request = rpc.request(proto::SendChannelMessage { channel_id, body });
-                let response = request.await?;
-                let sender = user_store.get_user(current_user_id).await?;
-
-                this.update(&mut cx, |this, cx| {
-                    if let Ok(i) = this
-                        .pending_messages
-                        .binary_search_by_key(&local_id, |msg| msg.local_id)
-                    {
-                        let body = this.pending_messages.remove(i).body;
-                        this.insert_messages(
-                            SumTree::from_item(
-                                ChannelMessage {
-                                    id: response.message_id,
-                                    timestamp: OffsetDateTime::from_unix_timestamp(
-                                        response.timestamp as i64,
-                                    )?,
-                                    body,
-                                    sender,
-                                },
-                                &(),
-                            ),
-                            cx,
-                        );
-                    }
-                    Ok(())
-                })
-            }
-            .log_err()
-        })
-        .detach();
-        cx.notify();
-        Ok(())
+        Ok(cx.spawn(|this, mut cx| async move {
+            let request = rpc.request(proto::SendChannelMessage { channel_id, body });
+            let response = request.await?;
+            let sender = user_store.get_user(current_user_id).await?;
+
+            this.update(&mut cx, |this, cx| {
+                if let Ok(i) = this
+                    .pending_messages
+                    .binary_search_by_key(&local_id, |msg| msg.local_id)
+                {
+                    let body = this.pending_messages.remove(i).body;
+                    this.insert_messages(
+                        SumTree::from_item(
+                            ChannelMessage {
+                                id: response.message_id,
+                                timestamp: OffsetDateTime::from_unix_timestamp(
+                                    response.timestamp as i64,
+                                )?,
+                                body,
+                                sender,
+                            },
+                            &(),
+                        ),
+                        cx,
+                    );
+                }
+                Ok(())
+            })
+        }))
     }
 
     pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {

zed/src/chat_panel.rs 🔗

@@ -193,9 +193,12 @@ impl ChatPanel {
                 body
             });
 
-            channel
+            if let Some(task) = channel
                 .update(cx, |channel, cx| channel.send_message(body, cx))
-                .log_err();
+                .log_err()
+            {
+                task.detach();
+            }
         }
     }
 

zrpc/proto/zed.proto 🔗

@@ -6,34 +6,35 @@ message Envelope {
     optional uint32 responding_to = 2;
     optional uint32 original_sender_id = 3;
     oneof payload {
-        Ping ping = 4;
-        Pong pong = 5;
-        ShareWorktree share_worktree = 6;
-        ShareWorktreeResponse share_worktree_response = 7;
-        OpenWorktree open_worktree = 8;
-        OpenWorktreeResponse open_worktree_response = 9;
-        UpdateWorktree update_worktree = 10;
-        CloseWorktree close_worktree = 11;
-        OpenBuffer open_buffer = 12;
-        OpenBufferResponse open_buffer_response = 13;
-        CloseBuffer close_buffer = 14;
-        UpdateBuffer update_buffer = 15;
-        SaveBuffer save_buffer = 16;
-        BufferSaved buffer_saved = 17;
-        AddPeer add_peer = 18;
-        RemovePeer remove_peer = 19;
-        GetChannels get_channels = 20;
-        GetChannelsResponse get_channels_response = 21;
-        GetUsers get_users = 22;
-        GetUsersResponse get_users_response = 23;
-        JoinChannel join_channel = 24;
-        JoinChannelResponse join_channel_response = 25;
-        LeaveChannel leave_channel = 26;
-        SendChannelMessage send_channel_message = 27;
-        SendChannelMessageResponse send_channel_message_response = 28;
-        ChannelMessageSent channel_message_sent = 29;
-        GetChannelMessages get_channel_messages = 30;
-        GetChannelMessagesResponse get_channel_messages_response = 31;
+        Error error = 4;
+        Ping ping = 5;
+        Pong pong = 6;
+        ShareWorktree share_worktree = 7;
+        ShareWorktreeResponse share_worktree_response = 8;
+        OpenWorktree open_worktree = 9;
+        OpenWorktreeResponse open_worktree_response = 10;
+        UpdateWorktree update_worktree = 11;
+        CloseWorktree close_worktree = 12;
+        OpenBuffer open_buffer = 13;
+        OpenBufferResponse open_buffer_response = 14;
+        CloseBuffer close_buffer = 15;
+        UpdateBuffer update_buffer = 16;
+        SaveBuffer save_buffer = 17;
+        BufferSaved buffer_saved = 18;
+        AddPeer add_peer = 19;
+        RemovePeer remove_peer = 20;
+        GetChannels get_channels = 21;
+        GetChannelsResponse get_channels_response = 22;
+        GetUsers get_users = 23;
+        GetUsersResponse get_users_response = 24;
+        JoinChannel join_channel = 25;
+        JoinChannelResponse join_channel_response = 26;
+        LeaveChannel leave_channel = 27;
+        SendChannelMessage send_channel_message = 28;
+        SendChannelMessageResponse send_channel_message_response = 29;
+        ChannelMessageSent channel_message_sent = 30;
+        GetChannelMessages get_channel_messages = 31;
+        GetChannelMessagesResponse get_channel_messages_response = 32;
     }
 }
 
@@ -47,6 +48,10 @@ message Pong {
     int32 id = 2;
 }
 
+message Error {
+    string message = 1;
+}
+
 message ShareWorktree {
     Worktree worktree = 1;
 }

zrpc/src/peer.rs 🔗

@@ -238,8 +238,12 @@ impl Peer {
                 .recv()
                 .await
                 .ok_or_else(|| anyhow!("connection was closed"))?;
-            T::Response::from_envelope(response)
-                .ok_or_else(|| anyhow!("received response of the wrong type"))
+            if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+                Err(anyhow!("request failed").context(error.message.clone()))
+            } else {
+                T::Response::from_envelope(response)
+                    .ok_or_else(|| anyhow!("received response of the wrong type"))
+            }
         }
     }
 
@@ -301,6 +305,25 @@ impl Peer {
         }
     }
 
+    pub fn respond_with_error<T: RequestMessage>(
+        self: &Arc<Self>,
+        receipt: Receipt<T>,
+        response: proto::Error,
+    ) -> impl Future<Output = Result<()>> {
+        let this = self.clone();
+        async move {
+            let mut connection = this.connection(receipt.sender_id).await?;
+            let message_id = connection
+                .next_message_id
+                .fetch_add(1, atomic::Ordering::SeqCst);
+            connection
+                .outgoing_tx
+                .send(response.into_envelope(message_id, Some(receipt.message_id), None))
+                .await?;
+            Ok(())
+        }
+    }
+
     fn connection(
         self: &Arc<Self>,
         connection_id: ConnectionId,

zrpc/src/proto.rs 🔗

@@ -125,6 +125,7 @@ messages!(
     ChannelMessageSent,
     CloseBuffer,
     CloseWorktree,
+    Error,
     GetChannelMessages,
     GetChannelMessagesResponse,
     GetChannels,