Add the ability to notify when a user accepts a contact request

Antonio Scandurra , Nathan Sobo , and Max Brunsfeld created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

crates/collab/src/db.rs        | 381 +++++++++++++++++++++++++----------
crates/collab/src/rpc.rs       |  48 ++-
crates/collab/src/rpc/store.rs |  37 ++-
3 files changed, 316 insertions(+), 150 deletions(-)

Detailed changes

crates/collab/src/db.rs 🔗

@@ -17,10 +17,11 @@ pub trait Db: Send + Sync {
     async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
     async fn destroy_user(&self, id: UserId) -> Result<()>;
 
-    async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
+    async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>>;
+    async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool>;
     async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
     async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
-    async fn dismiss_contact_request(
+    async fn dismiss_contact_notification(
         &self,
         responder_id: UserId,
         requester_id: UserId,
@@ -190,7 +191,7 @@ impl Db for PostgresDb {
 
     // contacts
 
-    async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
+    async fn get_contacts(&self, user_id: UserId) -> Result<Vec<Contact>> {
         let query = "
             SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
             FROM contacts
@@ -201,46 +202,67 @@ impl Db for PostgresDb {
             .bind(user_id)
             .fetch(&self.pool);
 
-        let mut current = vec![user_id];
-        let mut outgoing_requests = Vec::new();
-        let mut incoming_requests = Vec::new();
+        let mut contacts = vec![Contact::Accepted {
+            user_id,
+            should_notify: false,
+        }];
         while let Some(row) = rows.next().await {
             let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 
             if user_id_a == user_id {
                 if accepted {
-                    current.push(user_id_b);
+                    contacts.push(Contact::Accepted {
+                        user_id: user_id_b,
+                        should_notify: should_notify && a_to_b,
+                    });
                 } else if a_to_b {
-                    outgoing_requests.push(user_id_b);
+                    contacts.push(Contact::Outgoing { user_id: user_id_b })
                 } else {
-                    incoming_requests.push(IncomingContactRequest {
-                        requester_id: user_id_b,
+                    contacts.push(Contact::Incoming {
+                        user_id: user_id_b,
                         should_notify,
                     });
                 }
             } else {
                 if accepted {
-                    current.push(user_id_a);
+                    contacts.push(Contact::Accepted {
+                        user_id: user_id_a,
+                        should_notify: should_notify && !a_to_b,
+                    });
                 } else if a_to_b {
-                    incoming_requests.push(IncomingContactRequest {
-                        requester_id: user_id_a,
+                    contacts.push(Contact::Incoming {
+                        user_id: user_id_a,
                         should_notify,
                     });
                 } else {
-                    outgoing_requests.push(user_id_a);
+                    contacts.push(Contact::Outgoing { user_id: user_id_a });
                 }
             }
         }
 
-        current.sort_unstable();
-        outgoing_requests.sort_unstable();
-        incoming_requests.sort_unstable();
+        contacts.sort_unstable_by_key(|contact| contact.user_id());
 
-        Ok(Contacts {
-            current,
-            outgoing_requests,
-            incoming_requests,
-        })
+        Ok(contacts)
+    }
+
+    async fn has_contact(&self, user_id_1: UserId, user_id_2: UserId) -> Result<bool> {
+        let (id_a, id_b) = if user_id_1 < user_id_2 {
+            (user_id_1, user_id_2)
+        } else {
+            (user_id_2, user_id_1)
+        };
+
+        let query = "
+            SELECT 1 FROM contacts
+            WHERE user_id_a = $1 AND user_id_b = $2 AND accepted = 't'
+            LIMIT 1
+        ";
+        Ok(sqlx::query_scalar::<_, i32>(query)
+            .bind(id_a.0)
+            .bind(id_b.0)
+            .fetch_optional(&self.pool)
+            .await?
+            .is_some())
     }
 
     async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
@@ -254,7 +276,8 @@ impl Db for PostgresDb {
             VALUES ($1, $2, $3, 'f', 't')
             ON CONFLICT (user_id_a, user_id_b) DO UPDATE
             SET
-                accepted = 't'
+                accepted = 't',
+                should_notify = 'f'
             WHERE
                 NOT contacts.accepted AND
                 ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
@@ -297,21 +320,26 @@ impl Db for PostgresDb {
         }
     }
 
-    async fn dismiss_contact_request(
+    async fn dismiss_contact_notification(
         &self,
-        responder_id: UserId,
-        requester_id: UserId,
+        user_id: UserId,
+        contact_user_id: UserId,
     ) -> Result<()> {
-        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
-            (responder_id, requester_id, false)
+        let (id_a, id_b, a_to_b) = if user_id < contact_user_id {
+            (user_id, contact_user_id, true)
         } else {
-            (requester_id, responder_id, true)
+            (contact_user_id, user_id, false)
         };
 
         let query = "
             UPDATE contacts
             SET should_notify = 'f'
-            WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
+            WHERE
+                user_id_a = $1 AND user_id_b = $2 AND
+                (
+                    (a_to_b = $3 AND accepted) OR
+                    (a_to_b != $3 AND NOT accepted)
+                );
         ";
 
         let result = sqlx::query(query)
@@ -342,7 +370,7 @@ impl Db for PostgresDb {
         let result = if accept {
             let query = "
                 UPDATE contacts
-                SET accepted = 't', should_notify = 'f'
+                SET accepted = 't', should_notify = 't'
                 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
             ";
             sqlx::query(query)
@@ -702,10 +730,28 @@ pub struct ChannelMessage {
 }
 
 #[derive(Clone, Debug, PartialEq, Eq)]
-pub struct Contacts {
-    pub current: Vec<UserId>,
-    pub incoming_requests: Vec<IncomingContactRequest>,
-    pub outgoing_requests: Vec<UserId>,
+pub enum Contact {
+    Accepted {
+        user_id: UserId,
+        should_notify: bool,
+    },
+    Outgoing {
+        user_id: UserId,
+    },
+    Incoming {
+        user_id: UserId,
+        should_notify: bool,
+    },
+}
+
+impl Contact {
+    pub fn user_id(&self) -> UserId {
+        match self {
+            Contact::Accepted { user_id, .. } => *user_id,
+            Contact::Outgoing { user_id } => *user_id,
+            Contact::Incoming { user_id, .. } => *user_id,
+        }
+    }
 }
 
 #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
@@ -947,51 +993,60 @@ pub mod tests {
             // User starts with no contacts
             assert_eq!(
                 db.get_contacts(user_1).await.unwrap(),
-                Contacts {
-                    current: vec![user_1],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                vec![Contact::Accepted {
+                    user_id: user_1,
+                    should_notify: false
+                }],
             );
 
             // User requests a contact. Both users see the pending request.
             db.send_contact_request(user_1, user_2).await.unwrap();
+            assert!(!db.has_contact(user_1, user_2).await.unwrap());
+            assert!(!db.has_contact(user_2, user_1).await.unwrap());
             assert_eq!(
                 db.get_contacts(user_1).await.unwrap(),
-                Contacts {
-                    current: vec![user_1],
-                    outgoing_requests: vec![user_2],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Outgoing { user_id: user_2 }
+                ],
             );
             assert_eq!(
                 db.get_contacts(user_2).await.unwrap(),
-                Contacts {
-                    current: vec![user_2],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![IncomingContactRequest {
-                        requester_id: user_1,
+                &[
+                    Contact::Incoming {
+                        user_id: user_1,
                         should_notify: true
-                    }],
-                },
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: false
+                    },
+                ]
             );
 
             // User 2 dismisses the contact request notification without accepting or rejecting.
             // We shouldn't notify them again.
-            db.dismiss_contact_request(user_1, user_2)
+            db.dismiss_contact_notification(user_1, user_2)
                 .await
                 .unwrap_err();
-            db.dismiss_contact_request(user_2, user_1).await.unwrap();
+            db.dismiss_contact_notification(user_2, user_1)
+                .await
+                .unwrap();
             assert_eq!(
                 db.get_contacts(user_2).await.unwrap(),
-                Contacts {
-                    current: vec![user_2],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![IncomingContactRequest {
-                        requester_id: user_1,
+                &[
+                    Contact::Incoming {
+                        user_id: user_1,
                         should_notify: false
-                    }],
-                },
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: false
+                    },
+                ]
             );
 
             // User can't accept their own contact request
@@ -1005,44 +1060,106 @@ pub mod tests {
                 .unwrap();
             assert_eq!(
                 db.get_contacts(user_1).await.unwrap(),
-                Contacts {
-                    current: vec![user_1, user_2],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: true
+                    }
+                ],
             );
+            assert!(db.has_contact(user_1, user_2).await.unwrap());
+            assert!(db.has_contact(user_2, user_1).await.unwrap());
             assert_eq!(
                 db.get_contacts(user_2).await.unwrap(),
-                Contacts {
-                    current: vec![user_1, user_2],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false,
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: false,
+                    },
+                ]
             );
 
             // Users cannot re-request existing contacts.
             db.send_contact_request(user_1, user_2).await.unwrap_err();
             db.send_contact_request(user_2, user_1).await.unwrap_err();
 
+            // Users can't dismiss notifications of them accepting other users' requests.
+            db.dismiss_contact_notification(user_2, user_1)
+                .await
+                .unwrap_err();
+            assert_eq!(
+                db.get_contacts(user_1).await.unwrap(),
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: true,
+                    },
+                ]
+            );
+
+            // Users can dismiss notifications of other users accepting their requests.
+            db.dismiss_contact_notification(user_1, user_2)
+                .await
+                .unwrap();
+            assert_eq!(
+                db.get_contacts(user_1).await.unwrap(),
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: false,
+                    },
+                ]
+            );
+
             // Users send each other concurrent contact requests and
             // see that they are immediately accepted.
             db.send_contact_request(user_1, user_3).await.unwrap();
             db.send_contact_request(user_3, user_1).await.unwrap();
             assert_eq!(
                 db.get_contacts(user_1).await.unwrap(),
-                Contacts {
-                    current: vec![user_1, user_2, user_3],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: false,
+                    },
+                    Contact::Accepted {
+                        user_id: user_3,
+                        should_notify: false
+                    },
+                ]
             );
             assert_eq!(
                 db.get_contacts(user_3).await.unwrap(),
-                Contacts {
-                    current: vec![user_1, user_3],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_3,
+                        should_notify: false
+                    }
+                ],
             );
 
             // User declines a contact request. Both users see that it is gone.
@@ -1050,21 +1167,33 @@ pub mod tests {
             db.respond_to_contact_request(user_3, user_2, false)
                 .await
                 .unwrap();
+            assert!(!db.has_contact(user_2, user_3).await.unwrap());
+            assert!(!db.has_contact(user_3, user_2).await.unwrap());
             assert_eq!(
                 db.get_contacts(user_2).await.unwrap(),
-                Contacts {
-                    current: vec![user_1, user_2],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_2,
+                        should_notify: false
+                    }
+                ]
             );
             assert_eq!(
                 db.get_contacts(user_3).await.unwrap(),
-                Contacts {
-                    current: vec![user_1, user_3],
-                    outgoing_requests: vec![],
-                    incoming_requests: vec![],
-                },
+                &[
+                    Contact::Accepted {
+                        user_id: user_1,
+                        should_notify: false
+                    },
+                    Contact::Accepted {
+                        user_id: user_3,
+                        should_notify: false
+                    }
+                ],
             );
         }
     }
@@ -1219,40 +1348,51 @@ pub mod tests {
             unimplemented!()
         }
 
-        async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
+        async fn get_contacts(&self, id: UserId) -> Result<Vec<Contact>> {
             self.background.simulate_random_delay().await;
-            let mut current = vec![id];
-            let mut outgoing_requests = Vec::new();
-            let mut incoming_requests = Vec::new();
+            let mut contacts = vec![Contact::Accepted {
+                user_id: id,
+                should_notify: false,
+            }];
 
             for contact in self.contacts.lock().iter() {
                 if contact.requester_id == id {
                     if contact.accepted {
-                        current.push(contact.responder_id);
+                        contacts.push(Contact::Accepted {
+                            user_id: contact.responder_id,
+                            should_notify: contact.should_notify,
+                        });
                     } else {
-                        outgoing_requests.push(contact.responder_id);
+                        contacts.push(Contact::Outgoing {
+                            user_id: contact.responder_id,
+                        });
                     }
                 } else if contact.responder_id == id {
                     if contact.accepted {
-                        current.push(contact.requester_id);
+                        contacts.push(Contact::Accepted {
+                            user_id: contact.requester_id,
+                            should_notify: false,
+                        });
                     } else {
-                        incoming_requests.push(IncomingContactRequest {
-                            requester_id: contact.requester_id,
+                        contacts.push(Contact::Incoming {
+                            user_id: contact.requester_id,
                             should_notify: contact.should_notify,
                         });
                     }
                 }
             }
 
-            current.sort_unstable();
-            outgoing_requests.sort_unstable();
-            incoming_requests.sort_unstable();
+            contacts.sort_unstable_by_key(|contact| contact.user_id());
+            Ok(contacts)
+        }
 
-            Ok(Contacts {
-                current,
-                outgoing_requests,
-                incoming_requests,
-            })
+        async fn has_contact(&self, user_id_a: UserId, user_id_b: UserId) -> Result<bool> {
+            self.background.simulate_random_delay().await;
+            Ok(self.contacts.lock().iter().any(|contact| {
+                contact.accepted
+                    && ((contact.requester_id == user_id_a && contact.responder_id == user_id_b)
+                        || (contact.requester_id == user_id_b && contact.responder_id == user_id_a))
+            }))
         }
 
         async fn send_contact_request(
@@ -1274,6 +1414,7 @@ pub mod tests {
                         Err(anyhow!("contact already exists"))?;
                     } else {
                         contact.accepted = true;
+                        contact.should_notify = false;
                         return Ok(());
                     }
                 }
@@ -1294,22 +1435,29 @@ pub mod tests {
             Ok(())
         }
 
-        async fn dismiss_contact_request(
+        async fn dismiss_contact_notification(
             &self,
-            responder_id: UserId,
-            requester_id: UserId,
+            user_id: UserId,
+            contact_user_id: UserId,
         ) -> Result<()> {
             let mut contacts = self.contacts.lock();
             for contact in contacts.iter_mut() {
-                if contact.requester_id == requester_id && contact.responder_id == responder_id {
-                    if contact.accepted {
-                        return Err(anyhow!("contact already confirmed"));
-                    }
+                if contact.requester_id == contact_user_id
+                    && contact.responder_id == user_id
+                    && !contact.accepted
+                {
+                    contact.should_notify = false;
+                    return Ok(());
+                }
+                if contact.requester_id == user_id
+                    && contact.responder_id == contact_user_id
+                    && contact.accepted
+                {
                     contact.should_notify = false;
                     return Ok(());
                 }
             }
-            Err(anyhow!("no such contact request"))
+            Err(anyhow!("no such notification"))
         }
 
         async fn respond_to_contact_request(
@@ -1326,6 +1474,7 @@ pub mod tests {
                     }
                     if accept {
                         contact.accepted = true;
+                        contact.should_notify = true;
                     } else {
                         contacts.remove(ix);
                     }

crates/collab/src/rpc.rs 🔗

@@ -2,7 +2,7 @@ mod store;
 
 use crate::{
     auth,
-    db::{ChannelId, MessageId, UserId},
+    db::{self, ChannelId, MessageId, UserId},
     AppState, Result,
 };
 use anyhow::anyhow;
@@ -421,21 +421,27 @@ impl Server {
         let contacts = self.app_state.db.get_contacts(user_id).await?;
         let store = self.store().await;
         let updated_contact = store.contact_for_user(user_id);
-        for contact_user_id in contacts.current {
-            for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
-                self.peer
-                    .send(
-                        contact_conn_id,
-                        proto::UpdateContacts {
-                            contacts: vec![updated_contact.clone()],
-                            remove_contacts: Default::default(),
-                            incoming_requests: Default::default(),
-                            remove_incoming_requests: Default::default(),
-                            outgoing_requests: Default::default(),
-                            remove_outgoing_requests: Default::default(),
-                        },
-                    )
-                    .trace_err();
+        for contact in contacts {
+            if let db::Contact::Accepted {
+                user_id: contact_user_id,
+                ..
+            } = contact
+            {
+                for contact_conn_id in store.connection_ids_for_user(contact_user_id) {
+                    self.peer
+                        .send(
+                            contact_conn_id,
+                            proto::UpdateContacts {
+                                contacts: vec![updated_contact.clone()],
+                                remove_contacts: Default::default(),
+                                incoming_requests: Default::default(),
+                                remove_incoming_requests: Default::default(),
+                                outgoing_requests: Default::default(),
+                                remove_outgoing_requests: Default::default(),
+                            },
+                        )
+                        .trace_err();
+                }
             }
         }
         Ok(())
@@ -473,8 +479,12 @@ impl Server {
             guest_user_id = state.user_id_for_connection(request.sender_id)?;
         };
 
-        let guest_contacts = self.app_state.db.get_contacts(guest_user_id).await?;
-        if !guest_contacts.current.contains(&host_user_id) {
+        let has_contact = self
+            .app_state
+            .db
+            .has_contact(guest_user_id, host_user_id)
+            .await?;
+        if !has_contact {
             return Err(anyhow!("no such project"))?;
         }
 
@@ -1026,7 +1036,7 @@ impl Server {
         if request.payload.response == proto::ContactRequestResponse::Dismiss as i32 {
             self.app_state
                 .db
-                .dismiss_contact_request(responder_id, requester_id)
+                .dismiss_contact_notification(responder_id, requester_id)
                 .await?;
         } else {
             let accept = request.payload.response == proto::ContactRequestResponse::Accept as i32;

crates/collab/src/rpc/store.rs 🔗

@@ -217,23 +217,30 @@ impl Store {
             .is_empty()
     }
 
-    pub fn build_initial_contacts_update(&self, contacts: db::Contacts) -> proto::UpdateContacts {
+    pub fn build_initial_contacts_update(
+        &self,
+        contacts: Vec<db::Contact>,
+    ) -> proto::UpdateContacts {
         let mut update = proto::UpdateContacts::default();
-        for user_id in contacts.current {
-            update.contacts.push(self.contact_for_user(user_id));
-        }
-
-        for request in contacts.incoming_requests {
-            update
-                .incoming_requests
-                .push(proto::IncomingContactRequest {
-                    requester_id: request.requester_id.to_proto(),
-                    should_notify: request.should_notify,
-                })
-        }
 
-        for requested_user_id in contacts.outgoing_requests {
-            update.outgoing_requests.push(requested_user_id.to_proto())
+        for contact in contacts {
+            match contact {
+                db::Contact::Accepted { user_id, .. } => {
+                    update.contacts.push(self.contact_for_user(user_id));
+                }
+                db::Contact::Outgoing { user_id } => {
+                    update.outgoing_requests.push(user_id.to_proto())
+                }
+                db::Contact::Incoming {
+                    user_id,
+                    should_notify,
+                } => update
+                    .incoming_requests
+                    .push(proto::IncomingContactRequest {
+                        requester_id: user_id.to_proto(),
+                        should_notify,
+                    }),
+            }
         }
 
         update