Detailed changes
@@ -77,8 +77,14 @@ struct Channel {
connection_ids: HashSet<ConnectionId>,
}
+#[cfg(debug_assertions)]
+const MESSAGE_COUNT_PER_PAGE: usize = 10;
+
+#[cfg(not(debug_assertions))]
const MESSAGE_COUNT_PER_PAGE: usize = 50;
+const MAX_MESSAGE_LEN: usize = 1024;
+
impl Server {
pub fn new(
app_state: Arc<AppState>,
@@ -661,20 +667,33 @@ impl Server {
}
}
+ let receipt = request.receipt();
+ let body = request.payload.body.trim().to_string();
+ if body.len() > MAX_MESSAGE_LEN {
+ self.peer
+ .respond_with_error(
+ receipt,
+ proto::Error {
+ message: "message is too long".to_string(),
+ },
+ )
+ .await?;
+ return Ok(());
+ }
+
let timestamp = OffsetDateTime::now_utc();
let message_id = self
.app_state
.db
- .create_channel_message(channel_id, user_id, &request.payload.body, timestamp)
+ .create_channel_message(channel_id, user_id, &body, timestamp)
.await?
.to_proto();
- let receipt = request.receipt();
let message = proto::ChannelMessageSent {
channel_id: channel_id.to_proto(),
message: Some(proto::ChannelMessage {
sender_id: user_id.to_proto(),
id: message_id,
- body: request.payload.body,
+ body,
timestamp: timestamp.unix_timestamp() as u64,
}),
};
@@ -1530,18 +1549,25 @@ mod tests {
})
.await;
- channel_a.update(&mut cx_a, |channel, cx| {
- channel.send_message("oh, hi B.".to_string(), cx).unwrap();
- channel.send_message("sup".to_string(), cx).unwrap();
- assert_eq!(
+ channel_a
+ .update(&mut cx_a, |channel, cx| {
channel
- .pending_messages()
- .iter()
- .map(|m| &m.body)
- .collect::<Vec<_>>(),
- &["oh, hi B.", "sup"]
- )
- });
+ .send_message("oh, hi B.".to_string(), cx)
+ .unwrap()
+ .detach();
+ let task = channel.send_message("sup".to_string(), cx).unwrap();
+ assert_eq!(
+ channel
+ .pending_messages()
+ .iter()
+ .map(|m| &m.body)
+ .collect::<Vec<_>>(),
+ &["oh, hi B.", "sup"]
+ );
+ task
+ })
+ .await
+ .unwrap();
channel_a
.condition(&cx_a, |channel, _| channel.pending_messages().is_empty())
@@ -1582,6 +1608,59 @@ mod tests {
}
}
+ #[gpui::test]
+ async fn test_chat_message_validation(mut cx_a: TestAppContext) {
+ cx_a.foreground().forbid_parking();
+
+ let mut server = TestServer::start().await;
+ let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+
+ let db = &server.app_state.db;
+ let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+ let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
+ db.add_org_member(org_id, user_id_a, false).await.unwrap();
+ db.add_channel_member(channel_id, user_id_a, false)
+ .await
+ .unwrap();
+
+ let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+ let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
+ channels_a
+ .condition(&mut cx_a, |list, _| list.available_channels().is_some())
+ .await;
+ let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+ this.get_channel(channel_id.to_proto(), cx).unwrap()
+ });
+
+ // Leading and trailing whitespace are trimmed.
+ channel_a
+ .update(&mut cx_a, |channel, cx| {
+ channel
+ .send_message("\n surrounded by whitespace \n".to_string(), cx)
+ .unwrap()
+ })
+ .await
+ .unwrap();
+ assert_eq!(
+ db.get_channel_messages(channel_id, 10, None)
+ .await
+ .unwrap()
+ .iter()
+ .map(|m| &m.body)
+ .collect::<Vec<_>>(),
+ &["surrounded by whitespace"]
+ );
+
+ // Messages aren't allowed to be too long.
+ channel_a
+ .update(&mut cx_a, |channel, cx| {
+ let long_body = "this is long.\n".repeat(1024);
+ channel.send_message(long_body, cx).unwrap()
+ })
+ .await
+ .unwrap_err();
+ }
+
struct TestServer {
peer: Arc<Peer>,
app_state: Arc<AppState>,
@@ -224,7 +224,11 @@ impl Channel {
&self.details.name
}
- pub fn send_message(&mut self, body: String, cx: &mut ModelContext<Self>) -> Result<()> {
+ pub fn send_message(
+ &mut self,
+ body: String,
+ cx: &mut ModelContext<Self>,
+ ) -> Result<Task<Result<()>>> {
let channel_id = self.details.id;
let current_user_id = self.current_user_id()?;
let local_id = self.next_local_message_id;
@@ -235,41 +239,35 @@ impl Channel {
});
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
- cx.spawn(|this, mut cx| {
- async move {
- let request = rpc.request(proto::SendChannelMessage { channel_id, body });
- let response = request.await?;
- let sender = user_store.get_user(current_user_id).await?;
-
- this.update(&mut cx, |this, cx| {
- if let Ok(i) = this
- .pending_messages
- .binary_search_by_key(&local_id, |msg| msg.local_id)
- {
- let body = this.pending_messages.remove(i).body;
- this.insert_messages(
- SumTree::from_item(
- ChannelMessage {
- id: response.message_id,
- timestamp: OffsetDateTime::from_unix_timestamp(
- response.timestamp as i64,
- )?,
- body,
- sender,
- },
- &(),
- ),
- cx,
- );
- }
- Ok(())
- })
- }
- .log_err()
- })
- .detach();
- cx.notify();
- Ok(())
+ Ok(cx.spawn(|this, mut cx| async move {
+ let request = rpc.request(proto::SendChannelMessage { channel_id, body });
+ let response = request.await?;
+ let sender = user_store.get_user(current_user_id).await?;
+
+ this.update(&mut cx, |this, cx| {
+ if let Ok(i) = this
+ .pending_messages
+ .binary_search_by_key(&local_id, |msg| msg.local_id)
+ {
+ let body = this.pending_messages.remove(i).body;
+ this.insert_messages(
+ SumTree::from_item(
+ ChannelMessage {
+ id: response.message_id,
+ timestamp: OffsetDateTime::from_unix_timestamp(
+ response.timestamp as i64,
+ )?,
+ body,
+ sender,
+ },
+ &(),
+ ),
+ cx,
+ );
+ }
+ Ok(())
+ })
+ }))
}
pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
@@ -193,9 +193,12 @@ impl ChatPanel {
body
});
- channel
+ if let Some(task) = channel
.update(cx, |channel, cx| channel.send_message(body, cx))
- .log_err();
+ .log_err()
+ {
+ task.detach();
+ }
}
}
@@ -6,34 +6,35 @@ message Envelope {
optional uint32 responding_to = 2;
optional uint32 original_sender_id = 3;
oneof payload {
- Ping ping = 4;
- Pong pong = 5;
- ShareWorktree share_worktree = 6;
- ShareWorktreeResponse share_worktree_response = 7;
- OpenWorktree open_worktree = 8;
- OpenWorktreeResponse open_worktree_response = 9;
- UpdateWorktree update_worktree = 10;
- CloseWorktree close_worktree = 11;
- OpenBuffer open_buffer = 12;
- OpenBufferResponse open_buffer_response = 13;
- CloseBuffer close_buffer = 14;
- UpdateBuffer update_buffer = 15;
- SaveBuffer save_buffer = 16;
- BufferSaved buffer_saved = 17;
- AddPeer add_peer = 18;
- RemovePeer remove_peer = 19;
- GetChannels get_channels = 20;
- GetChannelsResponse get_channels_response = 21;
- GetUsers get_users = 22;
- GetUsersResponse get_users_response = 23;
- JoinChannel join_channel = 24;
- JoinChannelResponse join_channel_response = 25;
- LeaveChannel leave_channel = 26;
- SendChannelMessage send_channel_message = 27;
- SendChannelMessageResponse send_channel_message_response = 28;
- ChannelMessageSent channel_message_sent = 29;
- GetChannelMessages get_channel_messages = 30;
- GetChannelMessagesResponse get_channel_messages_response = 31;
+ Error error = 4;
+ Ping ping = 5;
+ Pong pong = 6;
+ ShareWorktree share_worktree = 7;
+ ShareWorktreeResponse share_worktree_response = 8;
+ OpenWorktree open_worktree = 9;
+ OpenWorktreeResponse open_worktree_response = 10;
+ UpdateWorktree update_worktree = 11;
+ CloseWorktree close_worktree = 12;
+ OpenBuffer open_buffer = 13;
+ OpenBufferResponse open_buffer_response = 14;
+ CloseBuffer close_buffer = 15;
+ UpdateBuffer update_buffer = 16;
+ SaveBuffer save_buffer = 17;
+ BufferSaved buffer_saved = 18;
+ AddPeer add_peer = 19;
+ RemovePeer remove_peer = 20;
+ GetChannels get_channels = 21;
+ GetChannelsResponse get_channels_response = 22;
+ GetUsers get_users = 23;
+ GetUsersResponse get_users_response = 24;
+ JoinChannel join_channel = 25;
+ JoinChannelResponse join_channel_response = 26;
+ LeaveChannel leave_channel = 27;
+ SendChannelMessage send_channel_message = 28;
+ SendChannelMessageResponse send_channel_message_response = 29;
+ ChannelMessageSent channel_message_sent = 30;
+ GetChannelMessages get_channel_messages = 31;
+ GetChannelMessagesResponse get_channel_messages_response = 32;
}
}
@@ -47,6 +48,10 @@ message Pong {
int32 id = 2;
}
+message Error {
+ string message = 1;
+}
+
message ShareWorktree {
Worktree worktree = 1;
}
@@ -238,8 +238,12 @@ impl Peer {
.recv()
.await
.ok_or_else(|| anyhow!("connection was closed"))?;
- T::Response::from_envelope(response)
- .ok_or_else(|| anyhow!("received response of the wrong type"))
+ if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
+ Err(anyhow!("request failed").context(error.message.clone()))
+ } else {
+ T::Response::from_envelope(response)
+ .ok_or_else(|| anyhow!("received response of the wrong type"))
+ }
}
}
@@ -301,6 +305,25 @@ impl Peer {
}
}
+ pub fn respond_with_error<T: RequestMessage>(
+ self: &Arc<Self>,
+ receipt: Receipt<T>,
+ response: proto::Error,
+ ) -> impl Future<Output = Result<()>> {
+ let this = self.clone();
+ async move {
+ let mut connection = this.connection(receipt.sender_id).await?;
+ let message_id = connection
+ .next_message_id
+ .fetch_add(1, atomic::Ordering::SeqCst);
+ connection
+ .outgoing_tx
+ .send(response.into_envelope(message_id, Some(receipt.message_id), None))
+ .await?;
+ Ok(())
+ }
+ }
+
fn connection(
self: &Arc<Self>,
connection_id: ConnectionId,
@@ -125,6 +125,7 @@ messages!(
ChannelMessageSent,
CloseBuffer,
CloseWorktree,
+ Error,
GetChannelMessages,
GetChannelMessagesResponse,
GetChannels,