Start work on RPC endpoints for dealing with contact requests

Max Brunsfeld and Nathan Sobo created

Co-authored-by: Nathan Sobo <nathan@zed.dev>

Change summary

crates/client/src/user.rs                                   |  42 ++
crates/collab/migrations/20220506130724_create_contacts.sql |   1 
crates/collab/src/db.rs                                     | 114 ++++++
crates/collab/src/rpc.rs                                    |  71 ++++
crates/rpc/src/macros.rs                                    |  67 ++++
crates/rpc/src/proto.rs                                     |  71 ----
crates/rpc/src/rpc.rs                                       |   1 
7 files changed, 289 insertions(+), 78 deletions(-)

Detailed changes

crates/client/src/user.rs 🔗

@@ -1,6 +1,6 @@
 use super::{http::HttpClient, proto, Client, Status, TypedEnvelope};
 use anyhow::{anyhow, Result};
-use futures::{future, AsyncReadExt};
+use futures::{future, AsyncReadExt, Future};
 use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
 use postage::{prelude::Stream, sink::Sink, watch};
 use rpc::proto::{RequestMessage, UsersResponse};
@@ -121,6 +121,13 @@ impl UserStore {
             user_ids.insert(contact.user_id);
             user_ids.extend(contact.projects.iter().flat_map(|w| &w.guests).copied());
         }
+        user_ids.extend(message.pending_requests_to_user_ids.iter());
+        user_ids.extend(
+            message
+                .pending_requests_from_user_ids
+                .iter()
+                .map(|req| req.user_id),
+        );
 
         let load_users = self.get_users(user_ids.into_iter().collect(), cx);
         cx.spawn(|this, mut cx| async move {
@@ -153,6 +160,39 @@ impl UserStore {
             .is_ok()
     }
 
+    pub fn request_contact(&self, to_user_id: u64) -> impl Future<Output = Result<()>> {
+        let client = self.client.upgrade();
+        async move {
+            client
+                .ok_or_else(|| anyhow!("not logged in"))?
+                .request(proto::RequestContact { to_user_id })
+                .await?;
+            Ok(())
+        }
+    }
+
+    pub fn respond_to_contact_request(
+        &self,
+        from_user_id: u64,
+        accept: bool,
+    ) -> impl Future<Output = Result<()>> {
+        let client = self.client.upgrade();
+        async move {
+            client
+                .ok_or_else(|| anyhow!("not logged in"))?
+                .request(proto::RespondToContactRequest {
+                    requesting_user_id: from_user_id,
+                    response: if accept {
+                        proto::ContactRequestResponse::Accept
+                    } else {
+                        proto::ContactRequestResponse::Reject
+                    } as i32,
+                })
+                .await?;
+            Ok(())
+        }
+    }
+
     pub fn get_users(
         &mut self,
         mut user_ids: Vec<u64>,

crates/collab/migrations/20220506130724_create_contacts.sql 🔗

@@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "contacts" (
     "user_id_a" INTEGER REFERENCES users (id) NOT NULL,
     "user_id_b" INTEGER REFERENCES users (id) NOT NULL,
     "a_to_b" BOOLEAN NOT NULL,
+    "should_notify" BOOLEAN NOT NULL,
     "accepted" BOOLEAN NOT NULL
 );
 

crates/collab/src/db.rs 🔗

@@ -19,6 +19,11 @@ pub trait Db: Send + Sync {
 
     async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
     async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
+    async fn dismiss_contact_request(
+        &self,
+        responder_id: UserId,
+        requester_id: UserId,
+    ) -> Result<()>;
     async fn respond_to_contact_request(
         &self,
         responder_id: UserId,
@@ -184,12 +189,12 @@ impl Db for PostgresDb {
 
     async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
         let query = "
-            SELECT user_id_a, user_id_b, a_to_b, accepted
+            SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
             FROM contacts
             WHERE user_id_a = $1 OR user_id_b = $1;
         ";
 
-        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool)>(query)
+        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
             .bind(user_id)
             .fetch(&self.pool);
 
@@ -197,7 +202,7 @@ impl Db for PostgresDb {
         let mut requests_sent = Vec::new();
         let mut requests_received = Vec::new();
         while let Some(row) = rows.next().await {
-            let (user_id_a, user_id_b, a_to_b, accepted) = row?;
+            let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
 
             if user_id_a == user_id {
                 if accepted {
@@ -205,13 +210,19 @@ impl Db for PostgresDb {
                 } else if a_to_b {
                     requests_sent.push(user_id_b);
                 } else {
-                    requests_received.push(user_id_b);
+                    requests_received.push(IncomingContactRequest {
+                        requesting_user_id: user_id_b,
+                        should_notify,
+                    });
                 }
             } else {
                 if accepted {
                     current.push(user_id_a);
                 } else if a_to_b {
-                    requests_received.push(user_id_a);
+                    requests_received.push(IncomingContactRequest {
+                        requesting_user_id: user_id_a,
+                        should_notify,
+                    });
                 } else {
                     requests_sent.push(user_id_a);
                 }
@@ -232,8 +243,8 @@ impl Db for PostgresDb {
             (receiver_id, sender_id, false)
         };
         let query = "
-            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted)
-            VALUES ($1, $2, $3, 'f')
+            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
+            VALUES ($1, $2, $3, 'f', 't')
             ON CONFLICT (user_id_a, user_id_b) DO UPDATE
             SET
                 accepted = 't'
@@ -270,7 +281,7 @@ impl Db for PostgresDb {
         let result = if accept {
             let query = "
                 UPDATE contacts
-                SET accepted = 't'
+                SET accepted = 't', should_notify = 'f'
                 WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
             ";
             sqlx::query(query)
@@ -298,6 +309,37 @@ impl Db for PostgresDb {
         }
     }
 
+    async fn dismiss_contact_request(
+        &self,
+        responder_id: UserId,
+        requester_id: UserId,
+    ) -> Result<()> {
+        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
+            (responder_id, requester_id, false)
+        } else {
+            (requester_id, responder_id, true)
+        };
+
+        let query = "
+            UPDATE contacts
+            SET should_notify = 'f'
+            WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
+        ";
+
+        let result = sqlx::query(query)
+            .bind(id_a.0)
+            .bind(id_b.0)
+            .bind(a_to_b)
+            .execute(&self.pool)
+            .await?;
+
+        if result.rows_affected() == 0 {
+            Err(anyhow!("no such contact request"))?;
+        }
+
+        Ok(())
+    }
+
     // access tokens
 
     async fn create_access_token_hash(
@@ -628,7 +670,13 @@ pub struct ChannelMessage {
 pub struct Contacts {
     pub current: Vec<UserId>,
     pub requests_sent: Vec<UserId>,
-    pub requests_received: Vec<UserId>,
+    pub requests_received: Vec<IncomingContactRequest>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct IncomingContactRequest {
+    pub requesting_user_id: UserId,
+    pub should_notify: bool,
 }
 
 fn fuzzy_like_string(string: &str) -> String {
@@ -886,7 +934,28 @@ pub mod tests {
                 Contacts {
                     current: vec![],
                     requests_sent: vec![],
-                    requests_received: vec![user_1],
+                    requests_received: vec![IncomingContactRequest {
+                        requesting_user_id: user_1,
+                        should_notify: true
+                    }],
+                },
+            );
+
+            // User 2 dismisses the contact request notification without accepting or rejecting.
+            // We shouldn't notify them again.
+            db.dismiss_contact_request(user_1, user_2)
+                .await
+                .unwrap_err();
+            db.dismiss_contact_request(user_2, user_1).await.unwrap();
+            assert_eq!(
+                db.get_contacts(user_2).await.unwrap(),
+                Contacts {
+                    current: vec![],
+                    requests_sent: vec![],
+                    requests_received: vec![IncomingContactRequest {
+                        requesting_user_id: user_1,
+                        should_notify: false
+                    }],
                 },
             );
 
@@ -1032,6 +1101,7 @@ pub mod tests {
         requester_id: UserId,
         responder_id: UserId,
         accepted: bool,
+        should_notify: bool,
     }
 
     impl FakeDb {
@@ -1124,7 +1194,10 @@ pub mod tests {
                     if contact.accepted {
                         current.push(contact.requester_id);
                     } else {
-                        requests_received.push(contact.requester_id);
+                        requests_received.push(IncomingContactRequest {
+                            requesting_user_id: contact.requester_id,
+                            should_notify: contact.should_notify,
+                        });
                     }
                 }
             }
@@ -1162,10 +1235,29 @@ pub mod tests {
                 requester_id,
                 responder_id,
                 accepted: false,
+                should_notify: true,
             });
             Ok(())
         }
 
+        async fn dismiss_contact_request(
+            &self,
+            responder_id: UserId,
+            requester_id: UserId,
+        ) -> Result<()> {
+            let mut contacts = self.contacts.lock();
+            for contact in contacts.iter_mut() {
+                if contact.requester_id == requester_id && contact.responder_id == responder_id {
+                    if contact.accepted {
+                        return Err(anyhow!("contact already confirmed"));
+                    }
+                    contact.should_notify = false;
+                    return Ok(());
+                }
+            }
+            Err(anyhow!("no such contact request"))
+        }
+
         async fn respond_to_contact_request(
             &self,
             responder_id: UserId,

crates/collab/src/rpc.rs 🔗

@@ -154,6 +154,8 @@ impl Server {
             .add_request_handler(Server::get_channels)
             .add_request_handler(Server::get_users)
             .add_request_handler(Server::fuzzy_search_users)
+            .add_request_handler(Server::request_contact)
+            .add_request_handler(Server::respond_to_contact_request)
             .add_request_handler(Server::join_channel)
             .add_message_handler(Server::leave_channel)
             .add_request_handler(Server::send_channel_message)
@@ -914,6 +916,48 @@ impl Server {
         Ok(())
     }
 
+    async fn request_contact(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::RequestContact>,
+        response: Response<proto::RequestContact>,
+    ) -> Result<()> {
+        let requester_id = self
+            .store
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
+        let responder_id = UserId::from_proto(request.payload.to_user_id);
+        self.app_state
+            .db
+            .send_contact_request(requester_id, responder_id)
+            .await?;
+        response.send(proto::Ack {})?;
+        Ok(())
+    }
+
+    async fn respond_to_contact_request(
+        self: Arc<Server>,
+        request: TypedEnvelope<proto::RespondToContactRequest>,
+        response: Response<proto::RespondToContactRequest>,
+    ) -> Result<()> {
+        let responder_id = self
+            .store
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
+        let requester_id = UserId::from_proto(request.payload.requesting_user_id);
+        self.app_state
+            .db
+            .respond_to_contact_request(
+                responder_id,
+                requester_id,
+                request.payload.response == proto::ContactRequestResponse::Accept as i32,
+            )
+            .await?;
+        response.send(proto::Ack {})?;
+        Ok(())
+    }
+
     #[instrument(skip(self, state, user_ids))]
     fn update_contacts_for_users<'a>(
         self: &Arc<Self>,
@@ -4911,6 +4955,33 @@ mod tests {
         }
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_contacts_requests(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+        cx_a.foreground().forbid_parking();
+
+        // Connect to a server as 3 clients.
+        let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+        let client_a = server.create_client(cx_a, "user_a").await;
+        let client_b = server.create_client(cx_b, "user_b").await;
+
+        client_a
+            .user_store
+            .read_with(cx_a, |store, _| {
+                store.request_contact(client_b.user_id().unwrap())
+            })
+            .await
+            .unwrap();
+
+        client_a.user_store.read_with(cx_a, |store, _| {
+            let contacts = store
+                .contacts()
+                .iter()
+                .map(|contact| contact.user.github_login.clone())
+                .collect::<Vec<_>>();
+            assert_eq!(contacts, &["user_b"])
+        });
+    }
+
     #[gpui::test(iterations = 10)]
     async fn test_following(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
         cx_a.foreground().forbid_parking();

crates/rpc/src/macros.rs 🔗

@@ -0,0 +1,67 @@
+#[macro_export]
+macro_rules! messages {
+    ($(($name:ident, $priority:ident)),* $(,)?) => {
+        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
+            match envelope.payload {
+                $(Some(envelope::Payload::$name(payload)) => {
+                    Some(Box::new(TypedEnvelope {
+                        sender_id,
+                        original_sender_id: envelope.original_sender_id.map(PeerId),
+                        message_id: envelope.id,
+                        payload,
+                    }))
+                }, )*
+                _ => None
+            }
+        }
+
+        $(
+            impl EnvelopedMessage for $name {
+                const NAME: &'static str = std::stringify!($name);
+                const PRIORITY: MessagePriority = MessagePriority::$priority;
+
+                fn into_envelope(
+                    self,
+                    id: u32,
+                    responding_to: Option<u32>,
+                    original_sender_id: Option<u32>,
+                ) -> Envelope {
+                    Envelope {
+                        id,
+                        responding_to,
+                        original_sender_id,
+                        payload: Some(envelope::Payload::$name(self)),
+                    }
+                }
+
+                fn from_envelope(envelope: Envelope) -> Option<Self> {
+                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
+                        Some(msg)
+                    } else {
+                        None
+                    }
+                }
+            }
+        )*
+    };
+}
+
+#[macro_export]
+macro_rules! request_messages {
+    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
+        $(impl RequestMessage for $request_name {
+            type Response = $response_name;
+        })*
+    };
+}
+
+#[macro_export]
+macro_rules! entity_messages {
+    ($id_field:ident, $($name:ident),* $(,)?) => {
+        $(impl EntityMessage for $name {
+            fn remote_entity_id(&self) -> u64 {
+                self.$id_field
+            }
+        })*
+    };
+}

crates/rpc/src/proto.rs 🔗

@@ -1,4 +1,4 @@
-use super::{ConnectionId, PeerId, TypedEnvelope};
+use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
 use anyhow::{anyhow, Result};
 use async_tungstenite::tungstenite::Message as WebSocketMessage;
 use futures::{SinkExt as _, StreamExt as _};
@@ -73,71 +73,6 @@ impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
     }
 }
 
-macro_rules! messages {
-    ($(($name:ident, $priority:ident)),* $(,)?) => {
-        pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
-            match envelope.payload {
-                $(Some(envelope::Payload::$name(payload)) => {
-                    Some(Box::new(TypedEnvelope {
-                        sender_id,
-                        original_sender_id: envelope.original_sender_id.map(PeerId),
-                        message_id: envelope.id,
-                        payload,
-                    }))
-                }, )*
-                _ => None
-            }
-        }
-
-        $(
-            impl EnvelopedMessage for $name {
-                const NAME: &'static str = std::stringify!($name);
-                const PRIORITY: MessagePriority = MessagePriority::$priority;
-
-                fn into_envelope(
-                    self,
-                    id: u32,
-                    responding_to: Option<u32>,
-                    original_sender_id: Option<u32>,
-                ) -> Envelope {
-                    Envelope {
-                        id,
-                        responding_to,
-                        original_sender_id,
-                        payload: Some(envelope::Payload::$name(self)),
-                    }
-                }
-
-                fn from_envelope(envelope: Envelope) -> Option<Self> {
-                    if let Some(envelope::Payload::$name(msg)) = envelope.payload {
-                        Some(msg)
-                    } else {
-                        None
-                    }
-                }
-            }
-        )*
-    };
-}
-
-macro_rules! request_messages {
-    ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
-        $(impl RequestMessage for $request_name {
-            type Response = $response_name;
-        })*
-    };
-}
-
-macro_rules! entity_messages {
-    ($id_field:ident, $($name:ident),* $(,)?) => {
-        $(impl EntityMessage for $name {
-            fn remote_entity_id(&self) -> u64 {
-                self.$id_field
-            }
-        })*
-    };
-}
-
 messages!(
     (Ack, Foreground),
     (AddProjectCollaborator, Foreground),
@@ -198,6 +133,8 @@ messages!(
     (ReloadBuffersResponse, Foreground),
     (RemoveProjectCollaborator, Foreground),
     (RenameProjectEntry, Foreground),
+    (RequestContact, Foreground),
+    (RespondToContactRequest, Foreground),
     (SaveBuffer, Foreground),
     (SearchProject, Background),
     (SearchProjectResponse, Background),
@@ -250,6 +187,8 @@ request_messages!(
     (RegisterProject, RegisterProjectResponse),
     (RegisterWorktree, Ack),
     (ReloadBuffers, ReloadBuffersResponse),
+    (RequestContact, Ack),
+    (RespondToContactRequest, Ack),
     (RenameProjectEntry, ProjectEntryResponse),
     (SaveBuffer, BufferSaved),
     (SearchProject, SearchProjectResponse),

crates/rpc/src/rpc.rs 🔗

@@ -4,5 +4,6 @@ mod peer;
 pub mod proto;
 pub use conn::Connection;
 pub use peer::*;
+mod macros;
 
 pub const PROTOCOL_VERSION: u32 = 16;