Implement persistence for contacts

Max Brunsfeld and Nathan Sobo created

Co-authored-by: Nathan Sobo <nathan@zed.dev>

Change summary

crates/collab/migrations/20220506130724_create_contacts.sql |  12 
crates/collab/src/db.rs                                     | 346 ++++++
2 files changed, 350 insertions(+), 8 deletions(-)

Detailed changes

crates/collab/migrations/20220506130724_create_contacts.sql 🔗

@@ -1,10 +1,10 @@
 CREATE TABLE IF NOT EXISTS "contacts" (
     "id" SERIAL PRIMARY KEY,
-    "requesting_user_id" INTEGER REFERENCES users (id) NOT NULL,
-    "receiving_user_id" INTEGER REFERENCES users (id) NOT NULL,
-    "accepted" BOOLEAN NOT NULL,
-    "blocked" BOOLEAN NOT NULL
+    "user_id_a" INTEGER REFERENCES users (id) NOT NULL,
+    "user_id_b" INTEGER REFERENCES users (id) NOT NULL,
+    "a_to_b" BOOLEAN NOT NULL,
+    "accepted" BOOLEAN NOT NULL
 );
 
-CREATE UNIQUE INDEX "index_org_contacts_requesting_user_id_and_receiving_user_id" ON "contacts" ("requesting_user_id", "receiving_user_id");
-CREATE UNIQUE INDEX "index_org_contacts_receiving_user" ON "contacts" ("receiving_user_id");
+CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b");
+CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b");

crates/collab/src/db.rs 🔗

@@ -1,6 +1,6 @@
-use anyhow::Context;
-use anyhow::Result;
+use anyhow::{anyhow, Context, Result};
 use async_trait::async_trait;
+use futures::StreamExt;
 use serde::Serialize;
 pub use sqlx::postgres::PgPoolOptions as DbOptions;
 use sqlx::{types::Uuid, FromRow};
@@ -16,6 +16,16 @@ pub trait Db: Send + Sync {
     async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
     async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
     async fn destroy_user(&self, id: UserId) -> Result<()>;
+
+    async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
+    async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
+    async fn respond_to_contact_request(
+        &self,
+        responder_id: UserId,
+        requester_id: UserId,
+        accept: bool,
+    ) -> Result<()>;
+
     async fn create_access_token_hash(
         &self,
         user_id: UserId,
@@ -24,6 +34,7 @@ pub trait Db: Send + Sync {
     ) -> Result<()>;
     async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
     #[cfg(any(test, feature = "seed-support"))]
+
     async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
     #[cfg(any(test, feature = "seed-support"))]
     async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
@@ -32,6 +43,7 @@ pub trait Db: Send + Sync {
     #[cfg(any(test, feature = "seed-support"))]
     async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
     #[cfg(any(test, feature = "seed-support"))]
+
     async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
     async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
     async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
@@ -168,6 +180,124 @@ impl Db for PostgresDb {
             .map(drop)?)
     }
 
+    // contacts
+
+    async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
+        let query = "
+            SELECT user_id_a, user_id_b, a_to_b, accepted
+            FROM contacts
+            WHERE user_id_a = $1 OR user_id_b = $1;
+        ";
+
+        let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool)>(query)
+            .bind(user_id)
+            .fetch(&self.pool);
+
+        let mut current = Vec::new();
+        let mut requests_sent = Vec::new();
+        let mut requests_received = Vec::new();
+        while let Some(row) = rows.next().await {
+            let (user_id_a, user_id_b, a_to_b, accepted) = row?;
+
+            if user_id_a == user_id {
+                if accepted {
+                    current.push(user_id_b);
+                } else if a_to_b {
+                    requests_sent.push(user_id_b);
+                } else {
+                    requests_received.push(user_id_b);
+                }
+            } else {
+                if accepted {
+                    current.push(user_id_a);
+                } else if a_to_b {
+                    requests_received.push(user_id_a);
+                } else {
+                    requests_sent.push(user_id_a);
+                }
+            }
+        }
+
+        Ok(Contacts {
+            current,
+            requests_sent,
+            requests_received,
+        })
+    }
+
+    async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> {
+        let (id_a, id_b, a_to_b) = if sender_id < receiver_id {
+            (sender_id, receiver_id, true)
+        } else {
+            (receiver_id, sender_id, false)
+        };
+        let query = "
+            INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted)
+            VALUES ($1, $2, $3, 'f')
+            ON CONFLICT (user_id_a, user_id_b) DO UPDATE
+            SET
+                accepted = 't'
+            WHERE
+                NOT contacts.accepted AND
+                ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR
+                (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a));
+        ";
+        let result = sqlx::query(query)
+            .bind(id_a.0)
+            .bind(id_b.0)
+            .bind(a_to_b)
+            .execute(&self.pool)
+            .await?;
+
+        if result.rows_affected() == 1 {
+            Ok(())
+        } else {
+            Err(anyhow!("contact already requested"))
+        }
+    }
+
+    async fn respond_to_contact_request(
+        &self,
+        responder_id: UserId,
+        requester_id: UserId,
+        accept: bool,
+    ) -> Result<()> {
+        let (id_a, id_b, a_to_b) = if responder_id < requester_id {
+            (responder_id, requester_id, false)
+        } else {
+            (requester_id, responder_id, true)
+        };
+        let result = if accept {
+            let query = "
+                UPDATE contacts
+                SET accepted = 't'
+                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
+            ";
+            sqlx::query(query)
+                .bind(id_a.0)
+                .bind(id_b.0)
+                .bind(a_to_b)
+                .execute(&self.pool)
+                .await?
+        } else {
+            let query = "
+                DELETE FROM contacts
+                WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted;
+            ";
+            sqlx::query(query)
+                .bind(id_a.0)
+                .bind(id_b.0)
+                .bind(a_to_b)
+                .execute(&self.pool)
+                .await?
+        };
+        if result.rows_affected() == 1 {
+            Ok(())
+        } else {
+            Err(anyhow!("no such contact request"))
+        }
+    }
+
     // access tokens
 
     async fn create_access_token_hash(
@@ -494,6 +624,13 @@ pub struct ChannelMessage {
     pub nonce: Uuid,
 }
 
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct Contacts {
+    pub current: Vec<UserId>,
+    pub requests_sent: Vec<UserId>,
+    pub requests_received: Vec<UserId>,
+}
+
 fn fuzzy_like_string(string: &str) -> String {
     let mut result = String::with_capacity(string.len() * 2 + 1);
     for c in string.chars() {
@@ -712,6 +849,122 @@ pub mod tests {
         }
     }
 
+    #[tokio::test(flavor = "multi_thread")]
+    async fn test_add_contacts() {
+        for test_db in [
+            TestDb::postgres().await,
+            TestDb::fake(Arc::new(gpui::executor::Background::new())),
+        ] {
+            let db = test_db.db();
+
+            let user_1 = db.create_user("user1", false).await.unwrap();
+            let user_2 = db.create_user("user2", false).await.unwrap();
+            let user_3 = db.create_user("user3", false).await.unwrap();
+
+            // User starts with no contacts
+            assert_eq!(
+                db.get_contacts(user_1).await.unwrap(),
+                Contacts {
+                    current: vec![],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+
+            // User requests a contact. Both users see the pending request.
+            db.send_contact_request(user_1, user_2).await.unwrap();
+            assert_eq!(
+                db.get_contacts(user_1).await.unwrap(),
+                Contacts {
+                    current: vec![],
+                    requests_sent: vec![user_2],
+                    requests_received: vec![],
+                },
+            );
+            assert_eq!(
+                db.get_contacts(user_2).await.unwrap(),
+                Contacts {
+                    current: vec![],
+                    requests_sent: vec![],
+                    requests_received: vec![user_1],
+                },
+            );
+
+            // User can't accept their own contact request
+            db.respond_to_contact_request(user_1, user_2, true)
+                .await
+                .unwrap_err();
+
+            // User accepts a contact request. Both users see the contact.
+            db.respond_to_contact_request(user_2, user_1, true)
+                .await
+                .unwrap();
+            assert_eq!(
+                db.get_contacts(user_1).await.unwrap(),
+                Contacts {
+                    current: vec![user_2],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+            assert_eq!(
+                db.get_contacts(user_2).await.unwrap(),
+                Contacts {
+                    current: vec![user_1],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+
+            // Users cannot re-request existing contacts.
+            db.send_contact_request(user_1, user_2).await.unwrap_err();
+            db.send_contact_request(user_2, user_1).await.unwrap_err();
+
+            // Users send each other concurrent contact requests and
+            // see that they are immediately accepted.
+            db.send_contact_request(user_1, user_3).await.unwrap();
+            db.send_contact_request(user_3, user_1).await.unwrap();
+            assert_eq!(
+                db.get_contacts(user_1).await.unwrap(),
+                Contacts {
+                    current: vec![user_2, user_3],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+            assert_eq!(
+                db.get_contacts(user_3).await.unwrap(),
+                Contacts {
+                    current: vec![user_1],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+
+            // User declines a contact request. Both users see that it is gone.
+            db.send_contact_request(user_2, user_3).await.unwrap();
+            db.respond_to_contact_request(user_3, user_2, false)
+                .await
+                .unwrap();
+            assert_eq!(
+                db.get_contacts(user_2).await.unwrap(),
+                Contacts {
+                    current: vec![user_1],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+            assert_eq!(
+                db.get_contacts(user_3).await.unwrap(),
+                Contacts {
+                    current: vec![user_1],
+                    requests_sent: vec![],
+                    requests_received: vec![],
+                },
+            );
+        }
+    }
+
     pub struct TestDb {
         pub db: Option<Arc<dyn Db>>,
         pub url: String,
@@ -772,6 +1025,13 @@ pub mod tests {
         channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
         channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
         next_channel_message_id: Mutex<i32>,
+        contacts: Mutex<Vec<FakeContact>>,
+    }
+
+    struct FakeContact {
+        requester_id: UserId,
+        responder_id: UserId,
+        accepted: bool,
     }
 
     impl FakeDb {
@@ -788,6 +1048,7 @@ pub mod tests {
                 channel_memberships: Default::default(),
                 channel_messages: Default::default(),
                 next_channel_message_id: Mutex::new(1),
+                contacts: Default::default(),
             }
         }
     }
@@ -847,6 +1108,87 @@ pub mod tests {
             unimplemented!()
         }
 
+        async fn get_contacts(&self, id: UserId) -> Result<Contacts> {
+            self.background.simulate_random_delay().await;
+            let mut current = Vec::new();
+            let mut requests_sent = Vec::new();
+            let mut requests_received = Vec::new();
+            for contact in self.contacts.lock().iter() {
+                if contact.requester_id == id {
+                    if contact.accepted {
+                        current.push(contact.responder_id);
+                    } else {
+                        requests_sent.push(contact.responder_id);
+                    }
+                } else if contact.responder_id == id {
+                    if contact.accepted {
+                        current.push(contact.requester_id);
+                    } else {
+                        requests_received.push(contact.requester_id);
+                    }
+                }
+            }
+            Ok(Contacts {
+                current,
+                requests_sent,
+                requests_received,
+            })
+        }
+
+        async fn send_contact_request(
+            &self,
+            requester_id: UserId,
+            responder_id: UserId,
+        ) -> Result<()> {
+            let mut contacts = self.contacts.lock();
+            for contact in contacts.iter_mut() {
+                if contact.requester_id == requester_id && contact.responder_id == responder_id {
+                    if contact.accepted {
+                        Err(anyhow!("contact already exists"))?;
+                    } else {
+                        Err(anyhow!("contact already requested"))?;
+                    }
+                }
+                if contact.responder_id == requester_id && contact.requester_id == responder_id {
+                    if contact.accepted {
+                        Err(anyhow!("contact already exists"))?;
+                    } else {
+                        contact.accepted = true;
+                        return Ok(());
+                    }
+                }
+            }
+            contacts.push(FakeContact {
+                requester_id,
+                responder_id,
+                accepted: false,
+            });
+            Ok(())
+        }
+
+        async fn respond_to_contact_request(
+            &self,
+            responder_id: UserId,
+            requester_id: UserId,
+            accept: bool,
+        ) -> Result<()> {
+            let mut contacts = self.contacts.lock();
+            for (ix, contact) in contacts.iter_mut().enumerate() {
+                if contact.requester_id == requester_id && contact.responder_id == responder_id {
+                    if contact.accepted {
+                        return Err(anyhow!("contact already confirmed"));
+                    }
+                    if accept {
+                        contact.accepted = true;
+                    } else {
+                        contacts.remove(ix);
+                    }
+                    return Ok(());
+                }
+            }
+            Err(anyhow!("no such contact request"))
+        }
+
         async fn create_access_token_hash(
             &self,
             _user_id: UserId,