Detailed changes
@@ -1,6 +1,6 @@
use super::{http::HttpClient, proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Result};
-use futures::{future, AsyncReadExt};
+use futures::{future, AsyncReadExt, Future};
use gpui::{AsyncAppContext, Entity, ImageData, ModelContext, ModelHandle, Task};
use postage::{prelude::Stream, sink::Sink, watch};
use rpc::proto::{RequestMessage, UsersResponse};
@@ -121,6 +121,13 @@ impl UserStore {
user_ids.insert(contact.user_id);
user_ids.extend(contact.projects.iter().flat_map(|w| &w.guests).copied());
}
+ user_ids.extend(message.pending_requests_to_user_ids.iter());
+ user_ids.extend(
+ message
+ .pending_requests_from_user_ids
+ .iter()
+ .map(|req| req.user_id),
+ );
let load_users = self.get_users(user_ids.into_iter().collect(), cx);
cx.spawn(|this, mut cx| async move {
@@ -153,6 +160,39 @@ impl UserStore {
.is_ok()
}
+ pub fn request_contact(&self, to_user_id: u64) -> impl Future<Output = Result<()>> {
+ let client = self.client.upgrade();
+ async move {
+ client
+ .ok_or_else(|| anyhow!("not logged in"))?
+ .request(proto::RequestContact { to_user_id })
+ .await?;
+ Ok(())
+ }
+ }
+
+ pub fn respond_to_contact_request(
+ &self,
+ from_user_id: u64,
+ accept: bool,
+ ) -> impl Future<Output = Result<()>> {
+ let client = self.client.upgrade();
+ async move {
+ client
+ .ok_or_else(|| anyhow!("not logged in"))?
+ .request(proto::RespondToContactRequest {
+ requesting_user_id: from_user_id,
+ response: if accept {
+ proto::ContactRequestResponse::Accept
+ } else {
+ proto::ContactRequestResponse::Reject
+ } as i32,
+ })
+ .await?;
+ Ok(())
+ }
+ }
+
pub fn get_users(
&mut self,
mut user_ids: Vec<u64>,
@@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS "contacts" (
"user_id_a" INTEGER REFERENCES users (id) NOT NULL,
"user_id_b" INTEGER REFERENCES users (id) NOT NULL,
"a_to_b" BOOLEAN NOT NULL,
+ "should_notify" BOOLEAN NOT NULL,
"accepted" BOOLEAN NOT NULL
);
@@ -19,6 +19,11 @@ pub trait Db: Send + Sync {
async fn get_contacts(&self, id: UserId) -> Result<Contacts>;
async fn send_contact_request(&self, requester_id: UserId, responder_id: UserId) -> Result<()>;
+ async fn dismiss_contact_request(
+ &self,
+ responder_id: UserId,
+ requester_id: UserId,
+ ) -> Result<()>;
async fn respond_to_contact_request(
&self,
responder_id: UserId,
@@ -184,12 +189,12 @@ impl Db for PostgresDb {
async fn get_contacts(&self, user_id: UserId) -> Result<Contacts> {
let query = "
- SELECT user_id_a, user_id_b, a_to_b, accepted
+ SELECT user_id_a, user_id_b, a_to_b, accepted, should_notify
FROM contacts
WHERE user_id_a = $1 OR user_id_b = $1;
";
- let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool)>(query)
+ let mut rows = sqlx::query_as::<_, (UserId, UserId, bool, bool, bool)>(query)
.bind(user_id)
.fetch(&self.pool);
@@ -197,7 +202,7 @@ impl Db for PostgresDb {
let mut requests_sent = Vec::new();
let mut requests_received = Vec::new();
while let Some(row) = rows.next().await {
- let (user_id_a, user_id_b, a_to_b, accepted) = row?;
+ let (user_id_a, user_id_b, a_to_b, accepted, should_notify) = row?;
if user_id_a == user_id {
if accepted {
@@ -205,13 +210,19 @@ impl Db for PostgresDb {
} else if a_to_b {
requests_sent.push(user_id_b);
} else {
- requests_received.push(user_id_b);
+ requests_received.push(IncomingContactRequest {
+ requesting_user_id: user_id_b,
+ should_notify,
+ });
}
} else {
if accepted {
current.push(user_id_a);
} else if a_to_b {
- requests_received.push(user_id_a);
+ requests_received.push(IncomingContactRequest {
+ requesting_user_id: user_id_a,
+ should_notify,
+ });
} else {
requests_sent.push(user_id_a);
}
@@ -232,8 +243,8 @@ impl Db for PostgresDb {
(receiver_id, sender_id, false)
};
let query = "
- INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted)
- VALUES ($1, $2, $3, 'f')
+ INSERT into contacts (user_id_a, user_id_b, a_to_b, accepted, should_notify)
+ VALUES ($1, $2, $3, 'f', 't')
ON CONFLICT (user_id_a, user_id_b) DO UPDATE
SET
accepted = 't'
@@ -270,7 +281,7 @@ impl Db for PostgresDb {
let result = if accept {
let query = "
UPDATE contacts
- SET accepted = 't'
+ SET accepted = 't', should_notify = 'f'
WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
";
sqlx::query(query)
@@ -298,6 +309,37 @@ impl Db for PostgresDb {
}
}
+ async fn dismiss_contact_request(
+ &self,
+ responder_id: UserId,
+ requester_id: UserId,
+ ) -> Result<()> {
+ let (id_a, id_b, a_to_b) = if responder_id < requester_id {
+ (responder_id, requester_id, false)
+ } else {
+ (requester_id, responder_id, true)
+ };
+
+ let query = "
+ UPDATE contacts
+ SET should_notify = 'f'
+ WHERE user_id_a = $1 AND user_id_b = $2 AND a_to_b = $3;
+ ";
+
+ let result = sqlx::query(query)
+ .bind(id_a.0)
+ .bind(id_b.0)
+ .bind(a_to_b)
+ .execute(&self.pool)
+ .await?;
+
+ if result.rows_affected() == 0 {
+ Err(anyhow!("no such contact request"))?;
+ }
+
+ Ok(())
+ }
+
// access tokens
async fn create_access_token_hash(
@@ -628,7 +670,13 @@ pub struct ChannelMessage {
pub struct Contacts {
pub current: Vec<UserId>,
pub requests_sent: Vec<UserId>,
- pub requests_received: Vec<UserId>,
+ pub requests_received: Vec<IncomingContactRequest>,
+}
+
+#[derive(Clone, Debug, PartialEq, Eq)]
+pub struct IncomingContactRequest {
+ pub requesting_user_id: UserId,
+ pub should_notify: bool,
}
fn fuzzy_like_string(string: &str) -> String {
@@ -886,7 +934,28 @@ pub mod tests {
Contacts {
current: vec![],
requests_sent: vec![],
- requests_received: vec![user_1],
+ requests_received: vec![IncomingContactRequest {
+ requesting_user_id: user_1,
+ should_notify: true
+ }],
+ },
+ );
+
+ // User 2 dismisses the contact request notification without accepting or rejecting.
+ // We shouldn't notify them again.
+ db.dismiss_contact_request(user_1, user_2)
+ .await
+ .unwrap_err();
+ db.dismiss_contact_request(user_2, user_1).await.unwrap();
+ assert_eq!(
+ db.get_contacts(user_2).await.unwrap(),
+ Contacts {
+ current: vec![],
+ requests_sent: vec![],
+ requests_received: vec![IncomingContactRequest {
+ requesting_user_id: user_1,
+ should_notify: false
+ }],
},
);
@@ -1032,6 +1101,7 @@ pub mod tests {
requester_id: UserId,
responder_id: UserId,
accepted: bool,
+ should_notify: bool,
}
impl FakeDb {
@@ -1124,7 +1194,10 @@ pub mod tests {
if contact.accepted {
current.push(contact.requester_id);
} else {
- requests_received.push(contact.requester_id);
+ requests_received.push(IncomingContactRequest {
+ requesting_user_id: contact.requester_id,
+ should_notify: contact.should_notify,
+ });
}
}
}
@@ -1162,10 +1235,29 @@ pub mod tests {
requester_id,
responder_id,
accepted: false,
+ should_notify: true,
});
Ok(())
}
+ async fn dismiss_contact_request(
+ &self,
+ responder_id: UserId,
+ requester_id: UserId,
+ ) -> Result<()> {
+ let mut contacts = self.contacts.lock();
+ for contact in contacts.iter_mut() {
+ if contact.requester_id == requester_id && contact.responder_id == responder_id {
+ if contact.accepted {
+ return Err(anyhow!("contact already confirmed"));
+ }
+ contact.should_notify = false;
+ return Ok(());
+ }
+ }
+ Err(anyhow!("no such contact request"))
+ }
+
async fn respond_to_contact_request(
&self,
responder_id: UserId,
@@ -154,6 +154,8 @@ impl Server {
.add_request_handler(Server::get_channels)
.add_request_handler(Server::get_users)
.add_request_handler(Server::fuzzy_search_users)
+ .add_request_handler(Server::request_contact)
+ .add_request_handler(Server::respond_to_contact_request)
.add_request_handler(Server::join_channel)
.add_message_handler(Server::leave_channel)
.add_request_handler(Server::send_channel_message)
@@ -914,6 +916,48 @@ impl Server {
Ok(())
}
+ async fn request_contact(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::RequestContact>,
+ response: Response<proto::RequestContact>,
+ ) -> Result<()> {
+ let requester_id = self
+ .store
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let responder_id = UserId::from_proto(request.payload.to_user_id);
+ self.app_state
+ .db
+ .send_contact_request(requester_id, responder_id)
+ .await?;
+ response.send(proto::Ack {})?;
+ Ok(())
+ }
+
+ async fn respond_to_contact_request(
+ self: Arc<Server>,
+ request: TypedEnvelope<proto::RespondToContactRequest>,
+ response: Response<proto::RespondToContactRequest>,
+ ) -> Result<()> {
+ let responder_id = self
+ .store
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let requester_id = UserId::from_proto(request.payload.requesting_user_id);
+ self.app_state
+ .db
+ .respond_to_contact_request(
+ responder_id,
+ requester_id,
+ request.payload.response == proto::ContactRequestResponse::Accept as i32,
+ )
+ .await?;
+ response.send(proto::Ack {})?;
+ Ok(())
+ }
+
#[instrument(skip(self, state, user_ids))]
fn update_contacts_for_users<'a>(
self: &Arc<Self>,
@@ -4911,6 +4955,33 @@ mod tests {
}
}
+ #[gpui::test(iterations = 10)]
+ async fn test_contacts_requests(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+ cx_a.foreground().forbid_parking();
+
+ // Connect to a server as 3 clients.
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
+ let client_a = server.create_client(cx_a, "user_a").await;
+ let client_b = server.create_client(cx_b, "user_b").await;
+
+ client_a
+ .user_store
+ .read_with(cx_a, |store, _| {
+ store.request_contact(client_b.user_id().unwrap())
+ })
+ .await
+ .unwrap();
+
+ client_a.user_store.read_with(cx_a, |store, _| {
+ let contacts = store
+ .contacts()
+ .iter()
+ .map(|contact| contact.user.github_login.clone())
+ .collect::<Vec<_>>();
+ assert_eq!(contacts, &["user_b"])
+ });
+ }
+
#[gpui::test(iterations = 10)]
async fn test_following(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
@@ -0,0 +1,67 @@
+#[macro_export]
+macro_rules! messages {
+ ($(($name:ident, $priority:ident)),* $(,)?) => {
+ pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
+ match envelope.payload {
+ $(Some(envelope::Payload::$name(payload)) => {
+ Some(Box::new(TypedEnvelope {
+ sender_id,
+ original_sender_id: envelope.original_sender_id.map(PeerId),
+ message_id: envelope.id,
+ payload,
+ }))
+ }, )*
+ _ => None
+ }
+ }
+
+ $(
+ impl EnvelopedMessage for $name {
+ const NAME: &'static str = std::stringify!($name);
+ const PRIORITY: MessagePriority = MessagePriority::$priority;
+
+ fn into_envelope(
+ self,
+ id: u32,
+ responding_to: Option<u32>,
+ original_sender_id: Option<u32>,
+ ) -> Envelope {
+ Envelope {
+ id,
+ responding_to,
+ original_sender_id,
+ payload: Some(envelope::Payload::$name(self)),
+ }
+ }
+
+ fn from_envelope(envelope: Envelope) -> Option<Self> {
+ if let Some(envelope::Payload::$name(msg)) = envelope.payload {
+ Some(msg)
+ } else {
+ None
+ }
+ }
+ }
+ )*
+ };
+}
+
+#[macro_export]
+macro_rules! request_messages {
+ ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
+ $(impl RequestMessage for $request_name {
+ type Response = $response_name;
+ })*
+ };
+}
+
+#[macro_export]
+macro_rules! entity_messages {
+ ($id_field:ident, $($name:ident),* $(,)?) => {
+ $(impl EntityMessage for $name {
+ fn remote_entity_id(&self) -> u64 {
+ self.$id_field
+ }
+ })*
+ };
+}
@@ -1,4 +1,4 @@
-use super::{ConnectionId, PeerId, TypedEnvelope};
+use super::{entity_messages, messages, request_messages, ConnectionId, PeerId, TypedEnvelope};
use anyhow::{anyhow, Result};
use async_tungstenite::tungstenite::Message as WebSocketMessage;
use futures::{SinkExt as _, StreamExt as _};
@@ -73,71 +73,6 @@ impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
}
}
-macro_rules! messages {
- ($(($name:ident, $priority:ident)),* $(,)?) => {
- pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
- match envelope.payload {
- $(Some(envelope::Payload::$name(payload)) => {
- Some(Box::new(TypedEnvelope {
- sender_id,
- original_sender_id: envelope.original_sender_id.map(PeerId),
- message_id: envelope.id,
- payload,
- }))
- }, )*
- _ => None
- }
- }
-
- $(
- impl EnvelopedMessage for $name {
- const NAME: &'static str = std::stringify!($name);
- const PRIORITY: MessagePriority = MessagePriority::$priority;
-
- fn into_envelope(
- self,
- id: u32,
- responding_to: Option<u32>,
- original_sender_id: Option<u32>,
- ) -> Envelope {
- Envelope {
- id,
- responding_to,
- original_sender_id,
- payload: Some(envelope::Payload::$name(self)),
- }
- }
-
- fn from_envelope(envelope: Envelope) -> Option<Self> {
- if let Some(envelope::Payload::$name(msg)) = envelope.payload {
- Some(msg)
- } else {
- None
- }
- }
- }
- )*
- };
-}
-
-macro_rules! request_messages {
- ($(($request_name:ident, $response_name:ident)),* $(,)?) => {
- $(impl RequestMessage for $request_name {
- type Response = $response_name;
- })*
- };
-}
-
-macro_rules! entity_messages {
- ($id_field:ident, $($name:ident),* $(,)?) => {
- $(impl EntityMessage for $name {
- fn remote_entity_id(&self) -> u64 {
- self.$id_field
- }
- })*
- };
-}
-
messages!(
(Ack, Foreground),
(AddProjectCollaborator, Foreground),
@@ -198,6 +133,8 @@ messages!(
(ReloadBuffersResponse, Foreground),
(RemoveProjectCollaborator, Foreground),
(RenameProjectEntry, Foreground),
+ (RequestContact, Foreground),
+ (RespondToContactRequest, Foreground),
(SaveBuffer, Foreground),
(SearchProject, Background),
(SearchProjectResponse, Background),
@@ -250,6 +187,8 @@ request_messages!(
(RegisterProject, RegisterProjectResponse),
(RegisterWorktree, Ack),
(ReloadBuffers, ReloadBuffersResponse),
+ (RequestContact, Ack),
+ (RespondToContactRequest, Ack),
(RenameProjectEntry, ProjectEntryResponse),
(SaveBuffer, BufferSaved),
(SearchProject, SearchProjectResponse),
@@ -4,5 +4,6 @@ mod peer;
pub mod proto;
pub use conn::Connection;
pub use peer::*;
+mod macros;
pub const PROTOCOL_VERSION: u32 = 16;