Remove contact notifications when cancelling a contact request

Max Brunsfeld created

Change summary

crates/channel/src/channel_store.rs                        |   3 
crates/collab/src/db/queries/contacts.rs                   |  22 +
crates/collab/src/db/queries/notifications.rs              |  45 ++
crates/collab/src/rpc.rs                                   |  11 
crates/collab_ui/src/notification_panel.rs                 |  66 ++-
crates/collab_ui/src/notifications/contact_notification.rs |  16 
crates/notifications/src/notification_store.rs             | 128 +++++--
crates/rpc/proto/zed.proto                                 |   7 
crates/rpc/src/notification.rs                             |  14 
crates/rpc/src/proto.rs                                    |   1 
10 files changed, 225 insertions(+), 88 deletions(-)

Detailed changes

crates/channel/src/channel_store.rs 🔗

@@ -127,9 +127,6 @@ impl ChannelStore {
                         this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx));
                     }
                 }
-                if status.is_connected() {
-                } else {
-                }
             }
             Some(())
         });

crates/collab/src/db/queries/contacts.rs 🔗

@@ -185,7 +185,11 @@ impl Database {
     ///
     /// * `requester_id` - The user that initiates this request
     /// * `responder_id` - The user that will be removed
-    pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<bool> {
+    pub async fn remove_contact(
+        &self,
+        requester_id: UserId,
+        responder_id: UserId,
+    ) -> Result<(bool, Option<NotificationId>)> {
         self.transaction(|tx| async move {
             let (id_a, id_b) = if responder_id < requester_id {
                 (responder_id, requester_id)
@@ -204,7 +208,21 @@ impl Database {
                 .ok_or_else(|| anyhow!("no such contact"))?;
 
             contact::Entity::delete_by_id(contact.id).exec(&*tx).await?;
-            Ok(contact.accepted)
+
+            let mut deleted_notification_id = None;
+            if !contact.accepted {
+                deleted_notification_id = self
+                    .delete_notification(
+                        responder_id,
+                        rpc::Notification::ContactRequest {
+                            actor_id: requester_id.to_proto(),
+                        },
+                        &*tx,
+                    )
+                    .await?;
+            }
+
+            Ok((contact.accepted, deleted_notification_id))
         })
         .await
     }

crates/collab/src/db/queries/notifications.rs 🔗

@@ -3,12 +3,12 @@ use rpc::Notification;
 
 impl Database {
     pub async fn initialize_notification_enum(&mut self) -> Result<()> {
-        notification_kind::Entity::insert_many(Notification::all_kinds().iter().map(|kind| {
-            notification_kind::ActiveModel {
+        notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map(
+            |kind| notification_kind::ActiveModel {
                 name: ActiveValue::Set(kind.to_string()),
                 ..Default::default()
-            }
-        }))
+            },
+        ))
         .on_conflict(OnConflict::new().do_nothing().to_owned())
         .exec_without_returning(&self.pool)
         .await?;
@@ -19,6 +19,12 @@ impl Database {
             self.notification_kinds_by_name.insert(row.name, row.id);
         }
 
+        for name in Notification::all_variant_names() {
+            if let Some(id) = self.notification_kinds_by_name.get(*name).copied() {
+                self.notification_kinds_by_id.insert(id, name);
+            }
+        }
+
         Ok(())
     }
 
@@ -46,6 +52,7 @@ impl Database {
             while let Some(row) = rows.next().await {
                 let row = row?;
                 let Some(kind) = self.notification_kinds_by_id.get(&row.kind) else {
+                    log::warn!("unknown notification kind {:?}", row.kind);
                     continue;
                 };
                 result.push(proto::Notification {
@@ -96,4 +103,34 @@ impl Database {
             actor_id: notification.actor_id,
         })
     }
+
+    pub async fn delete_notification(
+        &self,
+        recipient_id: UserId,
+        notification: Notification,
+        tx: &DatabaseTransaction,
+    ) -> Result<Option<NotificationId>> {
+        let notification = notification.to_any();
+        let kind = *self
+            .notification_kinds_by_name
+            .get(notification.kind.as_ref())
+            .ok_or_else(|| anyhow!("invalid notification kind {:?}", notification.kind))?;
+        let actor_id = notification.actor_id.map(|id| UserId::from_proto(id));
+        let notification = notification::Entity::find()
+            .filter(
+                Condition::all()
+                    .add(notification::Column::RecipientId.eq(recipient_id))
+                    .add(notification::Column::Kind.eq(kind))
+                    .add(notification::Column::ActorId.eq(actor_id))
+                    .add(notification::Column::Content.eq(notification.content)),
+            )
+            .one(tx)
+            .await?;
+        if let Some(notification) = &notification {
+            notification::Entity::delete_by_id(notification.id)
+                .exec(tx)
+                .await?;
+        }
+        Ok(notification.map(|notification| notification.id))
+    }
 }

crates/collab/src/rpc.rs 🔗

@@ -2177,7 +2177,8 @@ async fn remove_contact(
     let requester_id = session.user_id;
     let responder_id = UserId::from_proto(request.user_id);
     let db = session.db().await;
-    let contact_accepted = db.remove_contact(requester_id, responder_id).await?;
+    let (contact_accepted, deleted_notification_id) =
+        db.remove_contact(requester_id, responder_id).await?;
 
     let pool = session.connection_pool().await;
     // Update outgoing contact requests of requester
@@ -2204,6 +2205,14 @@ async fn remove_contact(
     }
     for connection_id in pool.user_connection_ids(responder_id) {
         session.peer.send(connection_id, update.clone())?;
+        if let Some(notification_id) = deleted_notification_id {
+            session.peer.send(
+                connection_id,
+                proto::DeleteNotification {
+                    notification_id: notification_id.to_proto(),
+                },
+            )?;
+        }
     }
 
     response.send(proto::Ack {})?;

crates/collab_ui/src/notification_panel.rs 🔗

@@ -301,6 +301,8 @@ impl NotificationPanel {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
+            NotificationEvent::NewNotification { entry } => self.add_toast(entry, cx),
+            NotificationEvent::NotificationRemoved { entry } => self.remove_toast(entry, cx),
             NotificationEvent::NotificationsUpdated {
                 old_range,
                 new_count,
@@ -308,31 +310,49 @@ impl NotificationPanel {
                 self.notification_list.splice(old_range.clone(), *new_count);
                 cx.notify();
             }
-            NotificationEvent::NewNotification { entry } => match entry.notification {
-                Notification::ContactRequest { actor_id }
-                | Notification::ContactRequestAccepted { actor_id } => {
-                    let user_store = self.user_store.clone();
-                    let Some(user) = user_store.read(cx).get_cached_user(actor_id) else {
-                        return;
-                    };
-                    self.workspace
-                        .update(cx, |workspace, cx| {
-                            workspace.show_notification(actor_id as usize, cx, |cx| {
-                                cx.add_view(|cx| {
-                                    ContactNotification::new(
-                                        user.clone(),
-                                        entry.notification.clone(),
-                                        user_store,
-                                        cx,
-                                    )
-                                })
+        }
+    }
+
+    fn add_toast(&mut self, entry: &NotificationEntry, cx: &mut ViewContext<Self>) {
+        let id = entry.id as usize;
+        match entry.notification {
+            Notification::ContactRequest { actor_id }
+            | Notification::ContactRequestAccepted { actor_id } => {
+                let user_store = self.user_store.clone();
+                let Some(user) = user_store.read(cx).get_cached_user(actor_id) else {
+                    return;
+                };
+                self.workspace
+                    .update(cx, |workspace, cx| {
+                        workspace.show_notification(id, cx, |cx| {
+                            cx.add_view(|_| {
+                                ContactNotification::new(
+                                    user,
+                                    entry.notification.clone(),
+                                    user_store,
+                                )
                             })
                         })
-                        .ok();
-                }
-                Notification::ChannelInvitation { .. } => {}
-                Notification::ChannelMessageMention { .. } => {}
-            },
+                    })
+                    .ok();
+            }
+            Notification::ChannelInvitation { .. } => {}
+            Notification::ChannelMessageMention { .. } => {}
+        }
+    }
+
+    fn remove_toast(&mut self, entry: &NotificationEntry, cx: &mut ViewContext<Self>) {
+        let id = entry.id as usize;
+        match entry.notification {
+            Notification::ContactRequest { .. } | Notification::ContactRequestAccepted { .. } => {
+                self.workspace
+                    .update(cx, |workspace, cx| {
+                        workspace.dismiss_notification::<ContactNotification>(id, cx)
+                    })
+                    .ok();
+            }
+            Notification::ChannelInvitation { .. } => {}
+            Notification::ChannelMessageMention { .. } => {}
         }
     }
 }

crates/collab_ui/src/notifications/contact_notification.rs 🔗

@@ -1,5 +1,5 @@
 use crate::notifications::render_user_notification;
-use client::{ContactEventKind, User, UserStore};
+use client::{User, UserStore};
 use gpui::{elements::*, Entity, ModelHandle, View, ViewContext};
 use std::sync::Arc;
 use workspace::notifications::Notification;
@@ -79,21 +79,7 @@ impl ContactNotification {
         user: Arc<User>,
         notification: rpc::Notification,
         user_store: ModelHandle<UserStore>,
-        cx: &mut ViewContext<Self>,
     ) -> Self {
-        cx.subscribe(&user_store, move |this, _, event, cx| {
-            if let client::Event::Contact {
-                kind: ContactEventKind::Cancelled,
-                user,
-            } = event
-            {
-                if user.id == this.user.id {
-                    cx.emit(Event::Dismiss);
-                }
-            }
-        })
-        .detach();
-
         Self {
             user,
             notification,

crates/notifications/src/notification_store.rs 🔗

@@ -2,11 +2,13 @@ use anyhow::Result;
 use channel::{ChannelMessage, ChannelMessageId, ChannelStore};
 use client::{Client, UserStore};
 use collections::HashMap;
+use db::smol::stream::StreamExt;
 use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task};
 use rpc::{proto, AnyNotification, Notification, TypedEnvelope};
 use std::{ops::Range, sync::Arc};
 use sum_tree::{Bias, SumTree};
 use time::OffsetDateTime;
+use util::ResultExt;
 
 pub fn init(client: Arc<Client>, user_store: ModelHandle<UserStore>, cx: &mut AppContext) {
     let notification_store = cx.add_model(|cx| NotificationStore::new(client, user_store, cx));
@@ -19,6 +21,7 @@ pub struct NotificationStore {
     channel_messages: HashMap<u64, ChannelMessage>,
     channel_store: ModelHandle<ChannelStore>,
     notifications: SumTree<NotificationEntry>,
+    _watch_connection_status: Task<Option<()>>,
     _subscriptions: Vec<client::Subscription>,
 }
 
@@ -30,6 +33,9 @@ pub enum NotificationEvent {
     NewNotification {
         entry: NotificationEntry,
     },
+    NotificationRemoved {
+        entry: NotificationEntry,
+    },
 }
 
 #[derive(Debug, PartialEq, Eq, Clone)]
@@ -66,19 +72,34 @@ impl NotificationStore {
         user_store: ModelHandle<UserStore>,
         cx: &mut ModelContext<Self>,
     ) -> Self {
-        let this = Self {
+        let mut connection_status = client.status();
+        let watch_connection_status = cx.spawn_weak(|this, mut cx| async move {
+            while let Some(status) = connection_status.next().await {
+                let this = this.upgrade(&cx)?;
+                match status {
+                    client::Status::Connected { .. } => {
+                        this.update(&mut cx, |this, cx| this.handle_connect(cx))
+                            .await
+                            .log_err()?;
+                    }
+                    _ => this.update(&mut cx, |this, cx| this.handle_disconnect(cx)),
+                }
+            }
+            Some(())
+        });
+
+        Self {
             channel_store: ChannelStore::global(cx),
             notifications: Default::default(),
             channel_messages: Default::default(),
+            _watch_connection_status: watch_connection_status,
             _subscriptions: vec![
-                client.add_message_handler(cx.handle(), Self::handle_new_notification)
+                client.add_message_handler(cx.handle(), Self::handle_new_notification),
+                client.add_message_handler(cx.handle(), Self::handle_delete_notification),
             ],
             user_store,
             client,
-        };
-
-        this.load_more_notifications(cx).detach();
-        this
+        }
     }
 
     pub fn notification_count(&self) -> usize {
@@ -110,6 +131,16 @@ impl NotificationStore {
         })
     }
 
+    fn handle_connect(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        self.notifications = Default::default();
+        self.channel_messages = Default::default();
+        self.load_more_notifications(cx)
+    }
+
+    fn handle_disconnect(&mut self, cx: &mut ModelContext<Self>) {
+        cx.notify()
+    }
+
     async fn handle_new_notification(
         this: ModelHandle<Self>,
         envelope: TypedEnvelope<proto::NewNotification>,
@@ -125,6 +156,18 @@ impl NotificationStore {
         .await
     }
 
+    async fn handle_delete_notification(
+        this: ModelHandle<Self>,
+        envelope: TypedEnvelope<proto::DeleteNotification>,
+        _: Arc<Client>,
+        mut cx: AsyncAppContext,
+    ) -> Result<()> {
+        this.update(&mut cx, |this, cx| {
+            this.splice_notifications([(envelope.payload.notification_id, None)], false, cx);
+            Ok(())
+        })
+    }
+
     async fn add_notifications(
         this: ModelHandle<Self>,
         is_new: bool,
@@ -205,26 +248,47 @@ impl NotificationStore {
                     }
                 }));
 
-            let mut cursor = this.notifications.cursor::<(NotificationId, Count)>();
-            let mut new_notifications = SumTree::new();
-            let mut old_range = 0..0;
-            for (i, notification) in notifications.into_iter().enumerate() {
-                new_notifications.append(
-                    cursor.slice(&NotificationId(notification.id), Bias::Left, &()),
-                    &(),
-                );
-
-                if i == 0 {
-                    old_range.start = cursor.start().1 .0;
-                }
+            this.splice_notifications(
+                notifications
+                    .into_iter()
+                    .map(|notification| (notification.id, Some(notification))),
+                is_new,
+                cx,
+            );
+        });
+
+        Ok(())
+    }
+
+    fn splice_notifications(
+        &mut self,
+        notifications: impl IntoIterator<Item = (u64, Option<NotificationEntry>)>,
+        is_new: bool,
+        cx: &mut ModelContext<'_, NotificationStore>,
+    ) {
+        let mut cursor = self.notifications.cursor::<(NotificationId, Count)>();
+        let mut new_notifications = SumTree::new();
+        let mut old_range = 0..0;
+
+        for (i, (id, new_notification)) in notifications.into_iter().enumerate() {
+            new_notifications.append(cursor.slice(&NotificationId(id), Bias::Left, &()), &());
 
-                if cursor
-                    .item()
-                    .map_or(true, |existing| existing.id != notification.id)
-                {
+            if i == 0 {
+                old_range.start = cursor.start().1 .0;
+            }
+
+            if let Some(existing_notification) = cursor.item() {
+                if existing_notification.id == id {
+                    if new_notification.is_none() {
+                        cx.emit(NotificationEvent::NotificationRemoved {
+                            entry: existing_notification.clone(),
+                        });
+                    }
                     cursor.next(&());
                 }
+            }
 
+            if let Some(notification) = new_notification {
                 if is_new {
                     cx.emit(NotificationEvent::NewNotification {
                         entry: notification.clone(),
@@ -233,20 +297,18 @@ impl NotificationStore {
 
                 new_notifications.push(notification, &());
             }
+        }
 
-            old_range.end = cursor.start().1 .0;
-            let new_count = new_notifications.summary().count;
-            new_notifications.append(cursor.suffix(&()), &());
-            drop(cursor);
+        old_range.end = cursor.start().1 .0;
+        let new_count = new_notifications.summary().count - old_range.start;
+        new_notifications.append(cursor.suffix(&()), &());
+        drop(cursor);
 
-            this.notifications = new_notifications;
-            cx.emit(NotificationEvent::NotificationsUpdated {
-                old_range,
-                new_count,
-            });
+        self.notifications = new_notifications;
+        cx.emit(NotificationEvent::NotificationsUpdated {
+            old_range,
+            new_count,
         });
-
-        Ok(())
     }
 }
 

crates/rpc/proto/zed.proto 🔗

@@ -177,7 +177,8 @@ message Envelope {
 
         NewNotification new_notification = 148;
         GetNotifications get_notifications = 149;
-        GetNotificationsResponse get_notifications_response = 150; // Current max
+        GetNotificationsResponse get_notifications_response = 150;
+        DeleteNotification delete_notification = 151; // Current max
     }
 }
 
@@ -1590,6 +1591,10 @@ message GetNotificationsResponse {
     repeated Notification notifications = 1;
 }
 
+message DeleteNotification {
+    uint64 notification_id = 1;
+}
+
 message Notification {
     uint64 id = 1;
     uint64 timestamp = 2;

crates/rpc/src/notification.rs 🔗

@@ -1,5 +1,5 @@
 use serde::{Deserialize, Serialize};
-use serde_json::Value;
+use serde_json::{map, Value};
 use std::borrow::Cow;
 use strum::{EnumVariantNames, IntoStaticStr, VariantNames as _};
 
@@ -47,10 +47,12 @@ impl Notification {
         let mut value = serde_json::to_value(self).unwrap();
         let mut actor_id = None;
         if let Some(value) = value.as_object_mut() {
-            value.remove("kind");
-            actor_id = value
-                .remove("actor_id")
-                .and_then(|value| Some(value.as_i64()? as u64));
+            value.remove(KIND);
+            if let map::Entry::Occupied(e) = value.entry(ACTOR_ID) {
+                if e.get().is_u64() {
+                    actor_id = e.remove().as_u64();
+                }
+            }
         }
         AnyNotification {
             kind: Cow::Borrowed(kind),
@@ -69,7 +71,7 @@ impl Notification {
         serde_json::from_value(value).ok()
     }
 
-    pub fn all_kinds() -> &'static [&'static str] {
+    pub fn all_variant_names() -> &'static [&'static str] {
         Self::VARIANTS
     }
 }

crates/rpc/src/proto.rs 🔗

@@ -155,6 +155,7 @@ messages!(
     (CreateRoomResponse, Foreground),
     (DeclineCall, Foreground),
     (DeleteChannel, Foreground),
+    (DeleteNotification, Foreground),
     (DeleteProjectEntry, Foreground),
     (Error, Foreground),
     (ExpandProjectEntry, Foreground),