diff --git a/crates/collab/migrations/20220506130724_create_contacts.sql b/crates/collab/migrations/20220506130724_create_contacts.sql index a292674ac47b0c3cb8536f68d53fe25d3364b2ac..216635b3195283c8a8a6cbf5319b572f45b7b04b 100644 --- a/crates/collab/migrations/20220506130724_create_contacts.sql +++ b/crates/collab/migrations/20220506130724_create_contacts.sql @@ -1,10 +1,10 @@ CREATE TABLE IF NOT EXISTS "contacts" ( "id" SERIAL PRIMARY KEY, - "requesting_user_id" INTEGER REFERENCES users (id) NOT NULL, - "receiving_user_id" INTEGER REFERENCES users (id) NOT NULL, - "accepted" BOOLEAN NOT NULL, - "blocked" BOOLEAN NOT NULL + "user_id_a" INTEGER REFERENCES users (id) NOT NULL, + "user_id_b" INTEGER REFERENCES users (id) NOT NULL, + "a_to_b" BOOLEAN NOT NULL, + "accepted" BOOLEAN NOT NULL ); -CREATE UNIQUE INDEX "index_org_contacts_requesting_user_id_and_receiving_user_id" ON "contacts" ("requesting_user_id", "receiving_user_id"); -CREATE UNIQUE INDEX "index_org_contacts_receiving_user" ON "contacts" ("receiving_user_id"); +CREATE UNIQUE INDEX "index_contacts_user_ids" ON "contacts" ("user_id_a", "user_id_b"); +CREATE INDEX "index_contacts_user_id_b" ON "contacts" ("user_id_b"); diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 3eb6fd0240ebe32c09d341ba60dc037c7ae3ab03..0f2c700c2cc5a9fd540dc51a19f2af01412b5d33 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -1,6 +1,6 @@ -use anyhow::Context; -use anyhow::Result; +use anyhow::{anyhow, Context, Result}; use async_trait::async_trait; +use futures::StreamExt; use serde::Serialize; pub use sqlx::postgres::PgPoolOptions as DbOptions; use sqlx::{types::Uuid, FromRow}; @@ -16,6 +16,16 @@ pub trait Db: Send + Sync { async fn get_user_by_github_login(&self, github_login: &str) -> Result>; async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; async fn destroy_user(&self, id: UserId) -> Result<()>; + + async fn get_contacts(&self, id: UserId) -> Result; + async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>; + async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()>; + async fn create_access_token_hash( &self, user_id: UserId, @@ -24,6 +34,7 @@ pub trait Db: Send + Sync { ) -> Result<()>; async fn get_access_token_hashes(&self, user_id: UserId) -> Result>; #[cfg(any(test, feature = "seed-support"))] + async fn find_org_by_slug(&self, slug: &str) -> Result>; #[cfg(any(test, feature = "seed-support"))] async fn create_org(&self, name: &str, slug: &str) -> Result; @@ -32,6 +43,7 @@ pub trait Db: Send + Sync { #[cfg(any(test, feature = "seed-support"))] async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result; #[cfg(any(test, feature = "seed-support"))] + async fn get_org_channels(&self, org_id: OrgId) -> Result>; async fn get_accessible_channels(&self, user_id: UserId) -> Result>; async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId) @@ -168,6 +180,124 @@ impl Db for PostgresDb { .map(drop)?) } + // contacts + + async fn get_contacts(&self, user_id: UserId) -> Result { + let query = " + SELECT user_id_a, user_id_b, a_to_b, accepted + FROM contacts + WHERE user_id_a = $1 OR user_id_b = $1; + "; + + let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool)>(query) + .bind(user_id) + .fetch(&self.pool); + + let mut current = Vec::new(); + let mut requests_sent = Vec::new(); + let mut requests_received = Vec::new(); + while let Some(row) = rows.next().await { + let (user_id_a, user_id_b, a_to_b, accepted) = row?; + + if user_id_a == user_id { + if accepted { + current.push(user_id_b); + } else if a_to_b { + requests_sent.push(user_id_b); + } else { + requests_received.push(user_id_b); + } + } else { + if accepted { + current.push(user_id_a); + } else if a_to_b { + requests_received.push(user_id_a); + } else { + requests_sent.push(user_id_a); + } + } + } + + Ok(Contacts { + current, + requests_sent, + requests_received, + }) + } + + async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + let (id_a, id_b, a_to_b) = if sender_id < receiver_id { + (sender_id, receiver_id, true) + } else { + (receiver_id, sender_id, false) + }; + let query = " + INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted) + VALUES ($1, $2, $3, 'f') + ON CONFLICT (user_id_a, user_id_b) DO UPDATE + SET + accepted = 't' + WHERE + NOT contacts.accepted AND + ((contacts.a_to_b = excluded.a_to_b AND contacts.user_id_a = excluded.user_id_b) OR + (contacts.a_to_b != excluded.a_to_b AND contacts.user_id_a = excluded.user_id_a)); + "; + let result = sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await?; + + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("contact already requested")) + } + } + + async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()> { + let (id_a, id_b, a_to_b) = if responder_id < requester_id { + (responder_id, requester_id, false) + } else { + (requester_id, responder_id, true) + }; + let result = if accept { + let query = " + UPDATE contacts + SET accepted = 't' + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3; + "; + sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await? + } else { + let query = " + DELETE FROM contacts + WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3 AND NOT accepted; + "; + sqlx::query(query) + .bind(id_a.0) + .bind(id_b.0) + .bind(a_to_b) + .execute(&self.pool) + .await? + }; + if result.rows_affected() == 1 { + Ok(()) + } else { + Err(anyhow!("no such contact request")) + } + } + // access tokens async fn create_access_token_hash( @@ -494,6 +624,13 @@ pub struct ChannelMessage { pub nonce: Uuid, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Contacts { + pub current: Vec, + pub requests_sent: Vec, + pub requests_received: Vec, +} + fn fuzzy_like_string(string: &str) -> String { let mut result = String::with_capacity(string.len() * 2 + 1); for c in string.chars() { @@ -712,6 +849,122 @@ pub mod tests { } } + #[tokio::test(flavor = "multi_thread")] + async fn test_add_contacts() { + for test_db in [ + TestDb::postgres().await, + TestDb::fake(Arc::new(gpui::executor::Background::new())), + ] { + let db = test_db.db(); + + let user_1 = db.create_user("user1", false).await.unwrap(); + let user_2 = db.create_user("user2", false).await.unwrap(); + let user_3 = db.create_user("user3", false).await.unwrap(); + + // User starts with no contacts + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + Contacts { + current: vec![], + requests_sent: vec![], + requests_received: vec![], + }, + ); + + // User requests a contact. Both users see the pending request. + db.send_contact_request(user_1, user_2).await.unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + Contacts { + current: vec![], + requests_sent: vec![user_2], + requests_received: vec![], + }, + ); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + Contacts { + current: vec![], + requests_sent: vec![], + requests_received: vec![user_1], + }, + ); + + // User can't accept their own contact request + db.respond_to_contact_request(user_1, user_2, true) + .await + .unwrap_err(); + + // User accepts a contact request. Both users see the contact. + db.respond_to_contact_request(user_2, user_1, true) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + Contacts { + current: vec![user_2], + requests_sent: vec![], + requests_received: vec![], + }, + ); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + Contacts { + current: vec![user_1], + requests_sent: vec![], + requests_received: vec![], + }, + ); + + // Users cannot re-request existing contacts. + db.send_contact_request(user_1, user_2).await.unwrap_err(); + db.send_contact_request(user_2, user_1).await.unwrap_err(); + + // Users send each other concurrent contact requests and + // see that they are immediately accepted. + db.send_contact_request(user_1, user_3).await.unwrap(); + db.send_contact_request(user_3, user_1).await.unwrap(); + assert_eq!( + db.get_contacts(user_1).await.unwrap(), + Contacts { + current: vec![user_2, user_3], + requests_sent: vec![], + requests_received: vec![], + }, + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + Contacts { + current: vec![user_1], + requests_sent: vec![], + requests_received: vec![], + }, + ); + + // User declines a contact request. Both users see that it is gone. + db.send_contact_request(user_2, user_3).await.unwrap(); + db.respond_to_contact_request(user_3, user_2, false) + .await + .unwrap(); + assert_eq!( + db.get_contacts(user_2).await.unwrap(), + Contacts { + current: vec![user_1], + requests_sent: vec![], + requests_received: vec![], + }, + ); + assert_eq!( + db.get_contacts(user_3).await.unwrap(), + Contacts { + current: vec![user_1], + requests_sent: vec![], + requests_received: vec![], + }, + ); + } + } + pub struct TestDb { pub db: Option>, pub url: String, @@ -772,6 +1025,13 @@ pub mod tests { channel_memberships: Mutex>, channel_messages: Mutex>, next_channel_message_id: Mutex, + contacts: Mutex>, + } + + struct FakeContact { + requester_id: UserId, + responder_id: UserId, + accepted: bool, } impl FakeDb { @@ -788,6 +1048,7 @@ pub mod tests { channel_memberships: Default::default(), channel_messages: Default::default(), next_channel_message_id: Mutex::new(1), + contacts: Default::default(), } } } @@ -847,6 +1108,87 @@ pub mod tests { unimplemented!() } + async fn get_contacts(&self, id: UserId) -> Result { + self.background.simulate_random_delay().await; + let mut current = Vec::new(); + let mut requests_sent = Vec::new(); + let mut requests_received = Vec::new(); + for contact in self.contacts.lock().iter() { + if contact.requester_id == id { + if contact.accepted { + current.push(contact.responder_id); + } else { + requests_sent.push(contact.responder_id); + } + } else if contact.responder_id == id { + if contact.accepted { + current.push(contact.requester_id); + } else { + requests_received.push(contact.requester_id); + } + } + } + Ok(Contacts { + current, + requests_sent, + requests_received, + }) + } + + async fn send_contact_request( + &self, + requester_id: UserId, + responder_id: UserId, + ) -> Result<()> { + let mut contacts = self.contacts.lock(); + for contact in contacts.iter_mut() { + if contact.requester_id == requester_id && contact.responder_id == responder_id { + if contact.accepted { + Err(anyhow!("contact already exists"))?; + } else { + Err(anyhow!("contact already requested"))?; + } + } + if contact.responder_id == requester_id && contact.requester_id == responder_id { + if contact.accepted { + Err(anyhow!("contact already exists"))?; + } else { + contact.accepted = true; + return Ok(()); + } + } + } + contacts.push(FakeContact { + requester_id, + responder_id, + accepted: false, + }); + Ok(()) + } + + async fn respond_to_contact_request( + &self, + responder_id: UserId, + requester_id: UserId, + accept: bool, + ) -> Result<()> { + let mut contacts = self.contacts.lock(); + for (ix, contact) in contacts.iter_mut().enumerate() { + if contact.requester_id == requester_id && contact.responder_id == responder_id { + if contact.accepted { + return Err(anyhow!("contact already confirmed")); + } + if accept { + contact.accepted = true; + } else { + contacts.remove(ix); + } + return Ok(()); + } + } + Err(anyhow!("no such contact request")) + } + async fn create_access_token_hash( &self, _user_id: UserId,