Create notifications for mentioned users

Max Brunsfeld and Piotr created

Co-authored-by: Piotr <piotr@zed.dev>

Change summary

crates/channel/src/channel_chat.rs               |  13 
crates/collab/src/db.rs                          |   7 
crates/collab/src/db/queries/messages.rs         | 166 ++++++++++++-----
crates/collab/src/db/tests/message_tests.rs      |  27 +-
crates/collab/src/rpc.rs                         |  59 ++++-
crates/collab/src/tests/channel_message_tests.rs |  24 ++
6 files changed, 212 insertions(+), 84 deletions(-)

Detailed changes

crates/channel/src/channel_chat.rs 🔗

@@ -135,7 +135,7 @@ impl ChannelChat {
         &mut self,
         message: MessageParams,
         cx: &mut ModelContext<Self>,
-    ) -> Result<Task<Result<()>>> {
+    ) -> Result<Task<Result<u64>>> {
         if message.text.is_empty() {
             Err(anyhow!("message body can't be empty"))?;
         }
@@ -176,15 +176,12 @@ impl ChannelChat {
             });
             let response = request.await?;
             drop(outgoing_message_guard);
-            let message = ChannelMessage::from_proto(
-                response.message.ok_or_else(|| anyhow!("invalid message"))?,
-                &user_store,
-                &mut cx,
-            )
-            .await?;
+            let response = response.message.ok_or_else(|| anyhow!("invalid message"))?;
+            let id = response.id;
+            let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?;
             this.update(&mut cx, |this, cx| {
                 this.insert_messages(SumTree::from_item(message, &()), cx);
-                Ok(())
+                Ok(id)
             })
         }))
     }

crates/collab/src/db.rs 🔗

@@ -386,6 +386,13 @@ impl Contact {
 
 pub type NotificationBatch = Vec<(UserId, proto::Notification)>;
 
+pub struct CreatedChannelMessage {
+    pub message_id: MessageId,
+    pub participant_connection_ids: Vec<ConnectionId>,
+    pub channel_members: Vec<UserId>,
+    pub notifications: NotificationBatch,
+}
+
 #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)]
 pub struct Invite {
     pub email_address: String,

crates/collab/src/db/queries/messages.rs 🔗

@@ -1,4 +1,5 @@
 use super::*;
+use futures::Stream;
 use sea_orm::TryInsertResult;
 use time::OffsetDateTime;
 
@@ -88,61 +89,46 @@ impl Database {
                 condition = condition.add(channel_message::Column::Id.lt(before_message_id));
             }
 
-            let mut rows = channel_message::Entity::find()
+            let rows = channel_message::Entity::find()
                 .filter(condition)
                 .order_by_desc(channel_message::Column::Id)
                 .limit(count as u64)
                 .stream(&*tx)
                 .await?;
 
-            let mut messages = Vec::new();
-            while let Some(row) = rows.next().await {
-                let row = row?;
-                let nonce = row.nonce.as_u64_pair();
-                messages.push(proto::ChannelMessage {
-                    id: row.id.to_proto(),
-                    sender_id: row.sender_id.to_proto(),
-                    body: row.body,
-                    timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
-                    mentions: vec![],
-                    nonce: Some(proto::Nonce {
-                        upper_half: nonce.0,
-                        lower_half: nonce.1,
-                    }),
-                });
-            }
-            drop(rows);
-            messages.reverse();
+            self.load_channel_messages(rows, &*tx).await
+        })
+        .await
+    }
 
-            let mut mentions = channel_message_mention::Entity::find()
-                .filter(
-                    channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)),
-                )
-                .order_by_asc(channel_message_mention::Column::MessageId)
-                .order_by_asc(channel_message_mention::Column::StartOffset)
+    pub async fn get_channel_messages_by_id(
+        &self,
+        user_id: UserId,
+        message_ids: &[MessageId],
+    ) -> Result<Vec<proto::ChannelMessage>> {
+        self.transaction(|tx| async move {
+            let rows = channel_message::Entity::find()
+                .filter(channel_message::Column::Id.is_in(message_ids.iter().copied()))
+                .order_by_desc(channel_message::Column::Id)
                 .stream(&*tx)
                 .await?;
 
-            let mut message_ix = 0;
-            while let Some(mention) = mentions.next().await {
-                let mention = mention?;
-                let message_id = mention.message_id.to_proto();
-                while let Some(message) = messages.get_mut(message_ix) {
-                    if message.id < message_id {
-                        message_ix += 1;
-                    } else {
-                        if message.id == message_id {
-                            message.mentions.push(proto::ChatMention {
-                                range: Some(proto::Range {
-                                    start: mention.start_offset as u64,
-                                    end: mention.end_offset as u64,
-                                }),
-                                user_id: mention.user_id.to_proto(),
-                            });
-                        }
-                        break;
-                    }
-                }
+            let mut channel_ids = HashSet::<ChannelId>::default();
+            let messages = self
+                .load_channel_messages(
+                    rows.map(|row| {
+                        row.map(|row| {
+                            channel_ids.insert(row.channel_id);
+                            row
+                        })
+                    }),
+                    &*tx,
+                )
+                .await?;
+
+            for channel_id in channel_ids {
+                self.check_user_is_channel_member(channel_id, user_id, &*tx)
+                    .await?;
             }
 
             Ok(messages)
@@ -150,6 +136,62 @@ impl Database {
         .await
     }
 
+    async fn load_channel_messages(
+        &self,
+        mut rows: impl Send + Unpin + Stream<Item = Result<channel_message::Model, sea_orm::DbErr>>,
+        tx: &DatabaseTransaction,
+    ) -> Result<Vec<proto::ChannelMessage>> {
+        let mut messages = Vec::new();
+        while let Some(row) = rows.next().await {
+            let row = row?;
+            let nonce = row.nonce.as_u64_pair();
+            messages.push(proto::ChannelMessage {
+                id: row.id.to_proto(),
+                sender_id: row.sender_id.to_proto(),
+                body: row.body,
+                timestamp: row.sent_at.assume_utc().unix_timestamp() as u64,
+                mentions: vec![],
+                nonce: Some(proto::Nonce {
+                    upper_half: nonce.0,
+                    lower_half: nonce.1,
+                }),
+            });
+        }
+        drop(rows);
+        messages.reverse();
+
+        let mut mentions = channel_message_mention::Entity::find()
+            .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id)))
+            .order_by_asc(channel_message_mention::Column::MessageId)
+            .order_by_asc(channel_message_mention::Column::StartOffset)
+            .stream(&*tx)
+            .await?;
+
+        let mut message_ix = 0;
+        while let Some(mention) = mentions.next().await {
+            let mention = mention?;
+            let message_id = mention.message_id.to_proto();
+            while let Some(message) = messages.get_mut(message_ix) {
+                if message.id < message_id {
+                    message_ix += 1;
+                } else {
+                    if message.id == message_id {
+                        message.mentions.push(proto::ChatMention {
+                            range: Some(proto::Range {
+                                start: mention.start_offset as u64,
+                                end: mention.end_offset as u64,
+                            }),
+                            user_id: mention.user_id.to_proto(),
+                        });
+                    }
+                    break;
+                }
+            }
+        }
+
+        Ok(messages)
+    }
+
     pub async fn create_channel_message(
         &self,
         channel_id: ChannelId,
@@ -158,7 +200,7 @@ impl Database {
         mentions: &[proto::ChatMention],
         timestamp: OffsetDateTime,
         nonce: u128,
-    ) -> Result<(MessageId, Vec<ConnectionId>, Vec<UserId>)> {
+    ) -> Result<CreatedChannelMessage> {
         self.transaction(|tx| async move {
             let mut rows = channel_chat_participant::Entity::find()
                 .filter(channel_chat_participant::Column::ChannelId.eq(channel_id))
@@ -206,10 +248,13 @@ impl Database {
             .await?;
 
             let message_id;
+            let mut notifications = Vec::new();
             match result {
                 TryInsertResult::Inserted(result) => {
                     message_id = result.last_insert_id;
-                    let models = mentions
+                    let mentioned_user_ids =
+                        mentions.iter().map(|m| m.user_id).collect::<HashSet<_>>();
+                    let mentions = mentions
                         .iter()
                         .filter_map(|mention| {
                             let range = mention.range.as_ref()?;
@@ -226,12 +271,28 @@ impl Database {
                             })
                         })
                         .collect::<Vec<_>>();
-                    if !models.is_empty() {
-                        channel_message_mention::Entity::insert_many(models)
+                    if !mentions.is_empty() {
+                        channel_message_mention::Entity::insert_many(mentions)
                             .exec(&*tx)
                             .await?;
                     }
 
+                    for mentioned_user in mentioned_user_ids {
+                        notifications.extend(
+                            self.create_notification(
+                                UserId::from_proto(mentioned_user),
+                                rpc::Notification::ChannelMessageMention {
+                                    message_id: message_id.to_proto(),
+                                    sender_id: user_id.to_proto(),
+                                    channel_id: channel_id.to_proto(),
+                                },
+                                false,
+                                &*tx,
+                            )
+                            .await?,
+                        );
+                    }
+
                     self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx)
                         .await?;
                 }
@@ -250,7 +311,12 @@ impl Database {
                 .await?;
             channel_members.retain(|member| !participant_user_ids.contains(member));
 
-            Ok((message_id, participant_connection_ids, channel_members))
+            Ok(CreatedChannelMessage {
+                message_id,
+                participant_connection_ids,
+                channel_members,
+                notifications,
+            })
         })
         .await
     }

crates/collab/src/db/tests/message_tests.rs 🔗

@@ -35,7 +35,7 @@ async fn test_channel_message_retrieval(db: &Arc<Database>) {
             )
             .await
             .unwrap()
-            .0
+            .message_id
             .to_proto(),
         );
     }
@@ -109,7 +109,7 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
         )
         .await
         .unwrap()
-        .0;
+        .message_id;
     let id2 = db
         .create_channel_message(
             channel,
@@ -121,7 +121,7 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
         )
         .await
         .unwrap()
-        .0;
+        .message_id;
     let id3 = db
         .create_channel_message(
             channel,
@@ -133,7 +133,7 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
         )
         .await
         .unwrap()
-        .0;
+        .message_id;
     let id4 = db
         .create_channel_message(
             channel,
@@ -145,7 +145,7 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
         )
         .await
         .unwrap()
-        .0;
+        .message_id;
 
     // As a different user, reuse one of the same nonces. This request succeeds
     // and returns a different id.
@@ -160,7 +160,7 @@ async fn test_channel_message_nonces(db: &Arc<Database>) {
         )
         .await
         .unwrap()
-        .0;
+        .message_id;
 
     assert_ne!(id1, id2);
     assert_eq!(id1, id3);
@@ -235,24 +235,27 @@ async fn test_unseen_channel_messages(db: &Arc<Database>) {
         .await
         .unwrap();
 
-    let (second_message, _, _) = db
+    let second_message = db
         .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2)
         .await
-        .unwrap();
+        .unwrap()
+        .message_id;
 
-    let (third_message, _, _) = db
+    let third_message = db
         .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3)
         .await
-        .unwrap();
+        .unwrap()
+        .message_id;
 
     db.join_channel_chat(channel_2, user_connection_id, user)
         .await
         .unwrap();
 
-    let (fourth_message, _, _) = db
+    let fourth_message = db
         .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4)
         .await
-        .unwrap();
+        .unwrap()
+        .message_id;
 
     // Check that observer has new messages
     let unseen_messages = db

crates/collab/src/rpc.rs 🔗

@@ -3,8 +3,8 @@ mod connection_pool;
 use crate::{
     auth,
     db::{
-        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, Database, MessageId,
-        ProjectId, RoomId, ServerId, User, UserId,
+        self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage,
+        Database, MessageId, ProjectId, RoomId, ServerId, User, UserId,
     },
     executor::Executor,
     AppState, Result,
@@ -271,6 +271,7 @@ impl Server {
             .add_request_handler(send_channel_message)
             .add_request_handler(remove_channel_message)
             .add_request_handler(get_channel_messages)
+            .add_request_handler(get_channel_messages_by_id)
             .add_request_handler(get_notifications)
             .add_request_handler(link_channel)
             .add_request_handler(unlink_channel)
@@ -2969,7 +2970,12 @@ async fn send_channel_message(
         .ok_or_else(|| anyhow!("nonce can't be blank"))?;
 
     let channel_id = ChannelId::from_proto(request.channel_id);
-    let (message_id, connection_ids, non_participants) = session
+    let CreatedChannelMessage {
+        message_id,
+        participant_connection_ids,
+        channel_members,
+        notifications,
+    } = session
         .db()
         .await
         .create_channel_message(
@@ -2989,15 +2995,19 @@ async fn send_channel_message(
         timestamp: timestamp.unix_timestamp() as u64,
         nonce: Some(nonce),
     };
-    broadcast(Some(session.connection_id), connection_ids, |connection| {
-        session.peer.send(
-            connection,
-            proto::ChannelMessageSent {
-                channel_id: channel_id.to_proto(),
-                message: Some(message.clone()),
-            },
-        )
-    });
+    broadcast(
+        Some(session.connection_id),
+        participant_connection_ids,
+        |connection| {
+            session.peer.send(
+                connection,
+                proto::ChannelMessageSent {
+                    channel_id: channel_id.to_proto(),
+                    message: Some(message.clone()),
+                },
+            )
+        },
+    );
     response.send(proto::SendChannelMessageResponse {
         message: Some(message),
     })?;
@@ -3005,7 +3015,7 @@ async fn send_channel_message(
     let pool = &*session.connection_pool().await;
     broadcast(
         None,
-        non_participants
+        channel_members
             .iter()
             .flat_map(|user_id| pool.user_connection_ids(*user_id)),
         |peer_id| {
@@ -3021,6 +3031,7 @@ async fn send_channel_message(
             )
         },
     );
+    send_notifications(pool, &session.peer, notifications);
 
     Ok(())
 }
@@ -3129,6 +3140,28 @@ async fn get_channel_messages(
     Ok(())
 }
 
+async fn get_channel_messages_by_id(
+    request: proto::GetChannelMessagesById,
+    response: Response<proto::GetChannelMessagesById>,
+    session: Session,
+) -> Result<()> {
+    let message_ids = request
+        .message_ids
+        .iter()
+        .map(|id| MessageId::from_proto(*id))
+        .collect::<Vec<_>>();
+    let messages = session
+        .db()
+        .await
+        .get_channel_messages_by_id(session.user_id, &message_ids)
+        .await?;
+    response.send(proto::GetChannelMessagesResponse {
+        done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+        messages,
+    })?;
+    Ok(())
+}
+
 async fn get_notifications(
     request: proto::GetNotifications,
     response: Response<proto::GetNotifications>,

crates/collab/src/tests/channel_message_tests.rs 🔗

@@ -2,6 +2,7 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer};
 use channel::{ChannelChat, ChannelMessageId, MessageParams};
 use collab_ui::chat_panel::ChatPanel;
 use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext};
+use rpc::Notification;
 use std::sync::Arc;
 use workspace::dock::Panel;
 
@@ -38,7 +39,7 @@ async fn test_basic_channel_messages(
         .await
         .unwrap();
 
-    channel_chat_a
+    let message_id = channel_chat_a
         .update(cx_a, |c, cx| {
             c.send_message(
                 MessageParams {
@@ -91,6 +92,27 @@ async fn test_basic_channel_messages(
             );
         });
     }
+
+    client_c.notification_store().update(cx_c, |store, _| {
+        assert_eq!(store.notification_count(), 2);
+        assert_eq!(store.unread_notification_count(), 1);
+        assert_eq!(
+            store.notification_at(0).unwrap().notification,
+            Notification::ChannelMessageMention {
+                message_id,
+                sender_id: client_a.id(),
+                channel_id,
+            }
+        );
+        assert_eq!(
+            store.notification_at(1).unwrap().notification,
+            Notification::ChannelInvitation {
+                channel_id,
+                channel_name: "the-channel".to_string(),
+                inviter_id: client_a.id()
+            }
+        );
+    });
 }
 
 #[gpui::test]