Include sanitized message in `SendChannelMessageResponse`

Antonio Scandurra , Nathan Sobo , and Max Brunsfeld created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

server/src/rpc.rs    | 33 ++++++++++++++++-----------------
zed/src/channel.rs   | 32 +++++++-------------------------
zrpc/proto/zed.proto |  5 ++---
3 files changed, 25 insertions(+), 45 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -695,25 +695,27 @@ impl Server {
             .create_channel_message(channel_id, user_id, &body, timestamp)
             .await?
             .to_proto();
-        let message = proto::ChannelMessageSent {
-            channel_id: channel_id.to_proto(),
-            message: Some(proto::ChannelMessage {
-                sender_id: user_id.to_proto(),
-                id: message_id,
-                body,
-                timestamp: timestamp.unix_timestamp() as u64,
-            }),
+        let message = proto::ChannelMessage {
+            sender_id: user_id.to_proto(),
+            id: message_id,
+            body,
+            timestamp: timestamp.unix_timestamp() as u64,
         };
         broadcast(request.sender_id, connection_ids, |conn_id| {
-            self.peer.send(conn_id, message.clone())
+            self.peer.send(
+                conn_id,
+                proto::ChannelMessageSent {
+                    channel_id: channel_id.to_proto(),
+                    message: Some(message.clone()),
+                },
+            )
         })
         .await?;
         self.peer
             .respond(
                 receipt,
                 proto::SendChannelMessageResponse {
-                    message_id,
-                    timestamp: timestamp.unix_timestamp() as u64,
+                    message: Some(message),
                 },
             )
             .await?;
@@ -1649,12 +1651,9 @@ mod tests {
             .unwrap_err();
 
         // Messages aren't allowed to be blank.
-        channel_a
-            .update(&mut cx_a, |channel, cx| {
-                channel.send_message(String::new(), cx).unwrap()
-            })
-            .await
-            .unwrap_err();
+        channel_a.update(&mut cx_a, |channel, cx| {
+            channel.send_message(String::new(), cx).unwrap_err()
+        });
 
         // Leading and trailing whitespace are trimmed.
         channel_a

zed/src/channel.rs 🔗

@@ -225,7 +225,6 @@ impl Channel {
         }
 
         let channel_id = self.details.id;
-        let current_user_id = self.current_user_id()?;
         let local_id = self.next_local_message_id;
         self.next_local_message_id += 1;
         self.pending_messages.push(PendingChannelMessage {
@@ -237,28 +236,18 @@ impl Channel {
         Ok(cx.spawn(|this, mut cx| async move {
             let request = rpc.request(proto::SendChannelMessage { channel_id, body });
             let response = request.await?;
-            let sender = user_store.get_user(current_user_id).await?;
-
+            let message = ChannelMessage::from_proto(
+                response.message.ok_or_else(|| anyhow!("invalid message"))?,
+                &user_store,
+            )
+            .await?;
             this.update(&mut cx, |this, cx| {
                 if let Ok(i) = this
                     .pending_messages
                     .binary_search_by_key(&local_id, |msg| msg.local_id)
                 {
-                    let body = this.pending_messages.remove(i).body;
-                    this.insert_messages(
-                        SumTree::from_item(
-                            ChannelMessage {
-                                id: response.message_id,
-                                timestamp: OffsetDateTime::from_unix_timestamp(
-                                    response.timestamp as i64,
-                                )?,
-                                body,
-                                sender,
-                            },
-                            &(),
-                        ),
-                        cx,
-                    );
+                    this.pending_messages.remove(i);
+                    this.insert_messages(SumTree::from_item(message, &()), cx);
                 }
                 Ok(())
             })
@@ -320,13 +309,6 @@ impl Channel {
         &self.pending_messages
     }
 
-    fn current_user_id(&self) -> Result<u64> {
-        self.rpc
-            .user_id()
-            .borrow()
-            .ok_or_else(|| anyhow!("not logged in"))
-    }
-
     fn handle_message_sent(
         &mut self,
         message: TypedEnvelope<ChannelMessageSent>,

zrpc/proto/zed.proto 🔗

@@ -158,8 +158,7 @@ message SendChannelMessage {
 }
 
 message SendChannelMessageResponse {
-    uint64 message_id = 1;
-    uint64 timestamp = 2;
+    ChannelMessage message = 1;
 }
 
 message ChannelMessageSent {
@@ -311,4 +310,4 @@ message ChannelMessage {
     string body = 2;
     uint64 timestamp = 3;
     uint64 sender_id = 4;
-}
+}