diff --git a/Cargo.lock b/Cargo.lock index 5fe28590a14fe4e8874022972a3baf3e7d6b7c4f..e817fed0dbfe37f8ac0ba05bf0f9de79a1ef8997 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1504,6 +1504,7 @@ dependencies = [ "lsp", "nanoid", "node_runtime", + "notifications", "parking_lot 0.11.2", "pretty_assertions", "project", @@ -1559,13 +1560,17 @@ dependencies = [ "fuzzy", "gpui", "language", + "lazy_static", "log", "menu", + "notifications", "picker", "postage", + "pretty_assertions", "project", "recent_projects", "rich_text", + "rpc", "schemars", "serde", "serde_derive", @@ -1573,6 +1578,7 @@ dependencies = [ "theme", "theme_selector", "time", + "tree-sitter-markdown", "util", "vcs_menu", "workspace", @@ -4730,6 +4736,26 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "notifications" +version = "0.1.0" +dependencies = [ + "anyhow", + "channel", + "client", + "clock", + "collections", + "db", + "feature_flags", + "gpui", + "rpc", + "settings", + "sum_tree", + "text", + "time", + "util", +] + [[package]] name = "ntapi" version = "0.3.7" @@ -6404,8 +6430,10 @@ dependencies = [ "rsa 0.4.0", "serde", "serde_derive", + "serde_json", "smol", "smol-timeout", + "strum", "tempdir", "tracing", "util", @@ -6626,6 +6654,12 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustversion" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ffc183a10b4478d04cbbbfc96d0873219d962dd5accaff2ffbd4ceb7df837f4" + [[package]] name = "rustybuzz" version = "0.3.0" @@ -7700,6 +7734,22 @@ name = "strum" version = "0.25.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "290d54ea6f91c969195bdbcd7442c8c2a2ba87da8bf60a7ee86a235d4bc1e125" +dependencies = [ + "strum_macros", +] + +[[package]] +name = "strum_macros" +version = "0.25.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ad8d03b598d3d0fff69bf533ee3ef19b8eeb342729596df84bcc7e1f96ec4059" +dependencies = [ + "heck 0.4.1", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.37", +] [[package]] name = "subtle" @@ -10098,6 +10148,7 @@ dependencies = [ "log", "lsp", "node_runtime", + "notifications", "num_cpus", "outline", "parking_lot 0.11.2", diff --git a/Cargo.toml b/Cargo.toml index 995cd15edd45d6c983fa20c19c6ba82749300260..cf977b8fe6dcecd39641f07802001afa1456d220 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -47,6 +47,7 @@ members = [ "crates/media", "crates/menu", "crates/node_runtime", + "crates/notifications", "crates/outline", "crates/picker", "crates/plugin", @@ -112,6 +113,7 @@ serde_derive = { version = "1.0", features = ["deserialize_in_place"] } serde_json = { version = "1.0", features = ["preserve_order", "raw_value"] } smallvec = { version = "1.6", features = ["union"] } smol = { version = "1.2" } +strum = { version = "0.25.0", features = ["derive"] } sysinfo = "0.29.10" tempdir = { version = "0.3.7" } thiserror = { version = "1.0.29" } diff --git a/assets/icons/bell.svg b/assets/icons/bell.svg new file mode 100644 index 0000000000000000000000000000000000000000..ea1c6dd42e8821b632f6de97d143a7b9f4b97fd2 --- /dev/null +++ b/assets/icons/bell.svg @@ -0,0 +1,8 @@ + + + diff --git a/assets/settings/default.json b/assets/settings/default.json index 4143e5dd41f1a1596ea532433d91cfaffe35f92a..e70b56335915c8b4b2397dcae73def3d0a7bcda3 100644 --- a/assets/settings/default.json +++ b/assets/settings/default.json @@ -142,6 +142,14 @@ // Default width of the channels panel. "default_width": 240 }, + "notification_panel": { + // Whether to show the collaboration panel button in the status bar. + "button": true, + // Where to dock channels panel. Can be 'left' or 'right'. + "dock": "right", + // Default width of the channels panel. + "default_width": 240 + }, "assistant": { // Whether to show the assistant panel button in the status bar. "button": true, diff --git a/crates/channel/src/channel.rs b/crates/channel/src/channel.rs index d31d4b3c8c9e77e94661835c06ea234c70ded416..b6db304a70c31deab55aec61d6f5912f8bfd3e20 100644 --- a/crates/channel/src/channel.rs +++ b/crates/channel/src/channel.rs @@ -7,7 +7,10 @@ use gpui::{AppContext, ModelHandle}; use std::sync::Arc; pub use channel_buffer::{ChannelBuffer, ChannelBufferEvent, ACKNOWLEDGE_DEBOUNCE_INTERVAL}; -pub use channel_chat::{ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId}; +pub use channel_chat::{ + mentions_to_proto, ChannelChat, ChannelChatEvent, ChannelMessage, ChannelMessageId, + MessageParams, +}; pub use channel_store::{ Channel, ChannelData, ChannelEvent, ChannelId, ChannelMembership, ChannelPath, ChannelStore, }; diff --git a/crates/channel/src/channel_chat.rs b/crates/channel/src/channel_chat.rs index 734182886b3bebeacd03dbc177bf8ffcb8ab64e2..ca344c409f5df1d09c830fbecc5b649fbdd3d844 100644 --- a/crates/channel/src/channel_chat.rs +++ b/crates/channel/src/channel_chat.rs @@ -3,12 +3,17 @@ use anyhow::{anyhow, Result}; use client::{ proto, user::{User, UserStore}, - Client, Subscription, TypedEnvelope, + Client, Subscription, TypedEnvelope, UserId, }; use futures::lock::Mutex; use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, Task}; use rand::prelude::*; -use std::{collections::HashSet, mem, ops::Range, sync::Arc}; +use std::{ + collections::HashSet, + mem, + ops::{ControlFlow, Range}, + sync::Arc, +}; use sum_tree::{Bias, SumTree}; use time::OffsetDateTime; use util::{post_inc, ResultExt as _, TryFutureExt}; @@ -16,6 +21,7 @@ use util::{post_inc, ResultExt as _, TryFutureExt}; pub struct ChannelChat { channel: Arc, messages: SumTree, + acknowledged_message_ids: HashSet, channel_store: ModelHandle, loaded_all_messages: bool, last_acknowledged_id: Option, @@ -27,6 +33,12 @@ pub struct ChannelChat { _subscription: Subscription, } +#[derive(Debug, PartialEq, Eq)] +pub struct MessageParams { + pub text: String, + pub mentions: Vec<(Range, UserId)>, +} + #[derive(Clone, Debug)] pub struct ChannelMessage { pub id: ChannelMessageId, @@ -34,6 +46,7 @@ pub struct ChannelMessage { pub timestamp: OffsetDateTime, pub sender: Arc, pub nonce: u128, + pub mentions: Vec<(Range, UserId)>, } #[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -105,6 +118,7 @@ impl ChannelChat { rpc: client, outgoing_messages_lock: Default::default(), messages: Default::default(), + acknowledged_message_ids: Default::default(), loaded_all_messages, next_pending_message_id: 0, last_acknowledged_id: None, @@ -120,12 +134,16 @@ impl ChannelChat { &self.channel } + pub fn client(&self) -> &Arc { + &self.rpc + } + pub fn send_message( &mut self, - body: String, + message: MessageParams, cx: &mut ModelContext, - ) -> Result>> { - if body.is_empty() { + ) -> Result>> { + if message.text.is_empty() { Err(anyhow!("message body can't be empty"))?; } @@ -142,9 +160,10 @@ impl ChannelChat { SumTree::from_item( ChannelMessage { id: pending_id, - body: body.clone(), + body: message.text.clone(), sender: current_user, timestamp: OffsetDateTime::now_utc(), + mentions: message.mentions.clone(), nonce, }, &(), @@ -158,20 +177,18 @@ impl ChannelChat { let outgoing_message_guard = outgoing_messages_lock.lock().await; let request = rpc.request(proto::SendChannelMessage { channel_id, - body, + body: message.text, nonce: Some(nonce.into()), + mentions: mentions_to_proto(&message.mentions), }); let response = request.await?; drop(outgoing_message_guard); - let message = ChannelMessage::from_proto( - response.message.ok_or_else(|| anyhow!("invalid message"))?, - &user_store, - &mut cx, - ) - .await?; + let response = response.message.ok_or_else(|| anyhow!("invalid message"))?; + let id = response.id; + let message = ChannelMessage::from_proto(response, &user_store, &mut cx).await?; this.update(&mut cx, |this, cx| { this.insert_messages(SumTree::from_item(message, &()), cx); - Ok(()) + Ok(id) }) })) } @@ -191,41 +208,76 @@ impl ChannelChat { }) } - pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { - if !self.loaded_all_messages { - let rpc = self.rpc.clone(); - let user_store = self.user_store.clone(); - let channel_id = self.channel.id; - if let Some(before_message_id) = - self.messages.first().and_then(|message| match message.id { - ChannelMessageId::Saved(id) => Some(id), - ChannelMessageId::Pending(_) => None, - }) - { - cx.spawn(|this, mut cx| { - async move { - let response = rpc - .request(proto::GetChannelMessages { - channel_id, - before_message_id, - }) - .await?; - let loaded_all_messages = response.done; - let messages = - messages_from_proto(response.messages, &user_store, &mut cx).await?; - this.update(&mut cx, |this, cx| { - this.loaded_all_messages = loaded_all_messages; - this.insert_messages(messages, cx); - }); - anyhow::Ok(()) + pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> Option>> { + if self.loaded_all_messages { + return None; + } + + let rpc = self.rpc.clone(); + let user_store = self.user_store.clone(); + let channel_id = self.channel.id; + let before_message_id = self.first_loaded_message_id()?; + Some(cx.spawn(|this, mut cx| { + async move { + let response = rpc + .request(proto::GetChannelMessages { + channel_id, + before_message_id, + }) + .await?; + let loaded_all_messages = response.done; + let messages = messages_from_proto(response.messages, &user_store, &mut cx).await?; + this.update(&mut cx, |this, cx| { + this.loaded_all_messages = loaded_all_messages; + this.insert_messages(messages, cx); + }); + anyhow::Ok(()) + } + .log_err() + })) + } + + pub fn first_loaded_message_id(&mut self) -> Option { + self.messages.first().and_then(|message| match message.id { + ChannelMessageId::Saved(id) => Some(id), + ChannelMessageId::Pending(_) => None, + }) + } + + /// Load all of the chat messages since a certain message id. + /// + /// For now, we always maintain a suffix of the channel's messages. + pub async fn load_history_since_message( + chat: ModelHandle, + message_id: u64, + mut cx: AsyncAppContext, + ) -> Option { + loop { + let step = chat.update(&mut cx, |chat, cx| { + if let Some(first_id) = chat.first_loaded_message_id() { + if first_id <= message_id { + let mut cursor = chat.messages.cursor::<(ChannelMessageId, Count)>(); + let message_id = ChannelMessageId::Saved(message_id); + cursor.seek(&message_id, Bias::Left, &()); + return ControlFlow::Break( + if cursor + .item() + .map_or(false, |message| message.id == message_id) + { + Some(cursor.start().1 .0) + } else { + None + }, + ); } - .log_err() - }) - .detach(); - return true; + } + ControlFlow::Continue(chat.load_more_messages(cx)) + }); + match step { + ControlFlow::Break(ix) => return ix, + ControlFlow::Continue(task) => task?.await?, } } - false } pub fn acknowledge_last_message(&mut self, cx: &mut ModelContext) { @@ -284,6 +336,7 @@ impl ChannelChat { let request = rpc.request(proto::SendChannelMessage { channel_id, body: pending_message.body, + mentions: mentions_to_proto(&pending_message.mentions), nonce: Some(pending_message.nonce.into()), }); let response = request.await?; @@ -319,6 +372,17 @@ impl ChannelChat { cursor.item().unwrap() } + pub fn acknowledge_message(&mut self, id: u64) { + if self.acknowledged_message_ids.insert(id) { + self.rpc + .send(proto::AckChannelMessage { + channel_id: self.channel.id, + message_id: id, + }) + .ok(); + } + } + pub fn messages_in_range(&self, range: Range) -> impl Iterator { let mut cursor = self.messages.cursor::(); cursor.seek(&Count(range.start), Bias::Right, &()); @@ -451,22 +515,7 @@ async fn messages_from_proto( user_store: &ModelHandle, cx: &mut AsyncAppContext, ) -> Result> { - let unique_user_ids = proto_messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store - .update(cx, |user_store, cx| { - user_store.get_users(unique_user_ids, cx) - }) - .await?; - - let mut messages = Vec::with_capacity(proto_messages.len()); - for message in proto_messages { - messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); - } + let messages = ChannelMessage::from_proto_vec(proto_messages, user_store, cx).await?; let mut result = SumTree::new(); result.extend(messages, &()); Ok(result) @@ -486,6 +535,14 @@ impl ChannelMessage { Ok(ChannelMessage { id: ChannelMessageId::Saved(message.id), body: message.body, + mentions: message + .mentions + .into_iter() + .filter_map(|mention| { + let range = mention.range?; + Some((range.start as usize..range.end as usize, mention.user_id)) + }) + .collect(), timestamp: OffsetDateTime::from_unix_timestamp(message.timestamp as i64)?, sender, nonce: message @@ -498,6 +555,43 @@ impl ChannelMessage { pub fn is_pending(&self) -> bool { matches!(self.id, ChannelMessageId::Pending(_)) } + + pub async fn from_proto_vec( + proto_messages: Vec, + user_store: &ModelHandle, + cx: &mut AsyncAppContext, + ) -> Result> { + let unique_user_ids = proto_messages + .iter() + .map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + user_store + .update(cx, |user_store, cx| { + user_store.get_users(unique_user_ids, cx) + }) + .await?; + + let mut messages = Vec::with_capacity(proto_messages.len()); + for message in proto_messages { + messages.push(ChannelMessage::from_proto(message, user_store, cx).await?); + } + Ok(messages) + } +} + +pub fn mentions_to_proto(mentions: &[(Range, UserId)]) -> Vec { + mentions + .iter() + .map(|(range, user_id)| proto::ChatMention { + range: Some(proto::Range { + start: range.start as u64, + end: range.end as u64, + }), + user_id: *user_id as u64, + }) + .collect() } impl sum_tree::Item for ChannelMessage { @@ -538,3 +632,12 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count { self.0 += summary.count; } } + +impl<'a> From<&'a str> for MessageParams { + fn from(value: &'a str) -> Self { + Self { + text: value.into(), + mentions: Vec::new(), + } + } +} diff --git a/crates/channel/src/channel_store.rs b/crates/channel/src/channel_store.rs index 9c80dcc2b742b07430a752286a2008bc5cfd05b2..221b84529706a36aa22371fedf2e8ceb5cc11987 100644 --- a/crates/channel/src/channel_store.rs +++ b/crates/channel/src/channel_store.rs @@ -1,6 +1,6 @@ mod channel_index; -use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat}; +use crate::{channel_buffer::ChannelBuffer, channel_chat::ChannelChat, ChannelMessage}; use anyhow::{anyhow, Result}; use channel_index::ChannelIndex; use client::{Client, Subscription, User, UserId, UserStore}; @@ -153,9 +153,6 @@ impl ChannelStore { this.update(&mut cx, |this, cx| this.handle_disconnect(true, cx)); } } - if status.is_connected() { - } else { - } } Some(()) }); @@ -242,6 +239,12 @@ impl ChannelStore { self.channel_index.by_id().values().nth(ix) } + pub fn has_channel_invitation(&self, channel_id: ChannelId) -> bool { + self.channel_invitations + .iter() + .any(|channel| channel.id == channel_id) + } + pub fn channel_invitations(&self) -> &[Arc] { &self.channel_invitations } @@ -274,6 +277,33 @@ impl ChannelStore { ) } + pub fn fetch_channel_messages( + &self, + message_ids: Vec, + cx: &mut ModelContext, + ) -> Task>> { + let request = if message_ids.is_empty() { + None + } else { + Some( + self.client + .request(proto::GetChannelMessagesById { message_ids }), + ) + }; + cx.spawn_weak(|this, mut cx| async move { + if let Some(request) = request { + let response = request.await?; + let this = this + .upgrade(&cx) + .ok_or_else(|| anyhow!("channel store dropped"))?; + let user_store = this.read_with(&cx, |this, _| this.user_store.clone()); + ChannelMessage::from_proto_vec(response.messages, &user_store, &mut cx).await + } else { + Ok(Vec::new()) + } + }) + } + pub fn has_channel_buffer_changed(&self, channel_id: ChannelId) -> Option { self.channel_index .by_id() @@ -694,14 +724,15 @@ impl ChannelStore { &mut self, channel_id: ChannelId, accept: bool, - ) -> impl Future> { + cx: &mut ModelContext, + ) -> Task> { let client = self.client.clone(); - async move { + cx.background().spawn(async move { client .request(proto::RespondToChannelInvite { channel_id, accept }) .await?; Ok(()) - } + }) } pub fn get_channel_member_details( diff --git a/crates/channel/src/channel_store_tests.rs b/crates/channel/src/channel_store_tests.rs index 23f2e11a03e3fdd577412c5382ff724212e09e0f..8cc9cb73daed6f32280b3e2d0c61b5235173de51 100644 --- a/crates/channel/src/channel_store_tests.rs +++ b/crates/channel/src/channel_store_tests.rs @@ -202,6 +202,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "a".into(), timestamp: 1000, sender_id: 5, + mentions: vec![], nonce: Some(1.into()), }, proto::ChannelMessage { @@ -209,6 +210,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "b".into(), timestamp: 1001, sender_id: 6, + mentions: vec![], nonce: Some(2.into()), }, ], @@ -255,6 +257,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { body: "c".into(), timestamp: 1002, sender_id: 7, + mentions: vec![], nonce: Some(3.into()), }), }); @@ -292,7 +295,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { // Scroll up to view older messages. channel.update(cx, |channel, cx| { - assert!(channel.load_more_messages(cx)); + channel.load_more_messages(cx).unwrap().detach(); }); let get_messages = server.receive::().await.unwrap(); assert_eq!(get_messages.payload.channel_id, 5); @@ -308,6 +311,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 998, sender_id: 5, nonce: Some(4.into()), + mentions: vec![], }, proto::ChannelMessage { id: 9, @@ -315,6 +319,7 @@ async fn test_channel_messages(cx: &mut TestAppContext) { timestamp: 999, sender_id: 6, nonce: Some(5.into()), + mentions: vec![], }, ], }, diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 6aa41708e3ae3e3c3504ab82791278c8a1837c0a..8299b7c6e4bc9898f055dde0d3ad3b68d172fe69 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -293,21 +293,19 @@ impl UserStore { // No need to paralellize here let mut updated_contacts = Vec::new(); for contact in message.contacts { - let should_notify = contact.should_notify; - updated_contacts.push(( - Arc::new(Contact::from_proto(contact, &this, &mut cx).await?), - should_notify, + updated_contacts.push(Arc::new( + Contact::from_proto(contact, &this, &mut cx).await?, )); } let mut incoming_requests = Vec::new(); for request in message.incoming_requests { - incoming_requests.push({ - let user = this - .update(&mut cx, |this, cx| this.get_user(request.requester_id, cx)) - .await?; - (user, request.should_notify) - }); + incoming_requests.push( + this.update(&mut cx, |this, cx| { + this.get_user(request.requester_id, cx) + }) + .await?, + ); } let mut outgoing_requests = Vec::new(); @@ -330,13 +328,7 @@ impl UserStore { this.contacts .retain(|contact| !removed_contacts.contains(&contact.user.id)); // Update existing contacts and insert new ones - for (updated_contact, should_notify) in updated_contacts { - if should_notify { - cx.emit(Event::Contact { - user: updated_contact.user.clone(), - kind: ContactEventKind::Accepted, - }); - } + for updated_contact in updated_contacts { match this.contacts.binary_search_by_key( &&updated_contact.user.github_login, |contact| &contact.user.github_login, @@ -359,14 +351,7 @@ impl UserStore { } }); // Update existing incoming requests and insert new ones - for (user, should_notify) in incoming_requests { - if should_notify { - cx.emit(Event::Contact { - user: user.clone(), - kind: ContactEventKind::Requested, - }); - } - + for user in incoming_requests { match this .incoming_contact_requests .binary_search_by_key(&&user.github_login, |contact| { @@ -415,6 +400,12 @@ impl UserStore { &self.incoming_contact_requests } + pub fn has_incoming_contact_request(&self, user_id: u64) -> bool { + self.incoming_contact_requests + .iter() + .any(|user| user.id == user_id) + } + pub fn outgoing_contact_requests(&self) -> &[Arc] { &self.outgoing_contact_requests } diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index bc6e09f3bd2fdefb2a042e1087352c4f3924df8d..64bc191b21b381d09b78e2ad61619612f292e731 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -73,6 +73,7 @@ git = { path = "../git", features = ["test-support"] } live_kit_client = { path = "../live_kit_client", features = ["test-support"] } lsp = { path = "../lsp", features = ["test-support"] } node_runtime = { path = "../node_runtime" } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } diff --git a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql index 8eb6b52fd8b4f0ece5c64f4c45c48da4ee97fe18..7fa808b498fee75b61946c0d0e442fd20ebcc7f2 100644 --- a/crates/collab/migrations.sqlite/20221109000000_test_schema.sql +++ b/crates/collab/migrations.sqlite/20221109000000_test_schema.sql @@ -192,7 +192,7 @@ CREATE INDEX "index_followers_on_room_id" ON "followers" ("room_id"); CREATE TABLE "channels" ( "id" INTEGER PRIMARY KEY AUTOINCREMENT, "name" VARCHAR NOT NULL, - "created_at" TIMESTAMP NOT NULL DEFAULT now, + "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, "visibility" VARCHAR NOT NULL ); @@ -214,7 +214,15 @@ CREATE TABLE IF NOT EXISTS "channel_messages" ( "nonce" BLOB NOT NULL ); CREATE INDEX "index_channel_messages_on_channel_id" ON "channel_messages" ("channel_id"); -CREATE UNIQUE INDEX "index_channel_messages_on_nonce" ON "channel_messages" ("nonce"); +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); + +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); CREATE TABLE "channel_paths" ( "id_path" TEXT NOT NULL PRIMARY KEY, @@ -314,3 +322,26 @@ CREATE TABLE IF NOT EXISTS "observed_channel_messages" ( ); CREATE UNIQUE INDEX "index_observed_channel_messages_user_and_channel_id" ON "observed_channel_messages" ("user_id", "channel_id"); + +CREATE TABLE "notification_kinds" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE "notifications" ( + "id" INTEGER PRIMARY KEY AUTOINCREMENT, + "created_at" TIMESTAMP NOT NULL default CURRENT_TIMESTAMP, + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab/migrations/20231004130100_create_notifications.sql b/crates/collab/migrations/20231004130100_create_notifications.sql new file mode 100644 index 0000000000000000000000000000000000000000..93c282c631f3d5545593b7c71f013d8457cd088a --- /dev/null +++ b/crates/collab/migrations/20231004130100_create_notifications.sql @@ -0,0 +1,22 @@ +CREATE TABLE "notification_kinds" ( + "id" SERIAL PRIMARY KEY, + "name" VARCHAR NOT NULL +); + +CREATE UNIQUE INDEX "index_notification_kinds_on_name" ON "notification_kinds" ("name"); + +CREATE TABLE notifications ( + "id" SERIAL PRIMARY KEY, + "created_at" TIMESTAMP NOT NULL DEFAULT now(), + "recipient_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + "kind" INTEGER NOT NULL REFERENCES notification_kinds (id), + "entity_id" INTEGER, + "content" TEXT, + "is_read" BOOLEAN NOT NULL DEFAULT FALSE, + "response" BOOLEAN +); + +CREATE INDEX + "index_notifications_on_recipient_id_is_read_kind_entity_id" + ON "notifications" + ("recipient_id", "is_read", "kind", "entity_id"); diff --git a/crates/collab/migrations/20231018102700_create_mentions.sql b/crates/collab/migrations/20231018102700_create_mentions.sql new file mode 100644 index 0000000000000000000000000000000000000000..221a1748cfe16276deb4fc3dd2329983340307e7 --- /dev/null +++ b/crates/collab/migrations/20231018102700_create_mentions.sql @@ -0,0 +1,11 @@ +CREATE TABLE "channel_message_mentions" ( + "message_id" INTEGER NOT NULL REFERENCES channel_messages (id) ON DELETE CASCADE, + "start_offset" INTEGER NOT NULL, + "end_offset" INTEGER NOT NULL, + "user_id" INTEGER NOT NULL REFERENCES users (id) ON DELETE CASCADE, + PRIMARY KEY(message_id, start_offset) +); + +-- We use 'on conflict update' with this index, so it should be per-user. +CREATE UNIQUE INDEX "index_channel_messages_on_sender_id_nonce" ON "channel_messages" ("sender_id", "nonce"); +DROP INDEX "index_channel_messages_on_nonce"; diff --git a/crates/collab/src/bin/seed.rs b/crates/collab/src/bin/seed.rs index cb1594e941a0ebb1735d77c258fb9b4706880bde..88fe0a647b8924b2df1312aa8a9a3bd68b5d99f1 100644 --- a/crates/collab/src/bin/seed.rs +++ b/crates/collab/src/bin/seed.rs @@ -71,7 +71,6 @@ async fn main() { db::NewUserParams { github_login: github_user.login, github_user_id: github_user.id, - invite_count: 5, }, ) .await diff --git a/crates/collab/src/db.rs b/crates/collab/src/db.rs index 08f78c685dc8a28392733816a8e0473c2d2ca63a..5f3d0fc0c7c7775b2fb2bb69257873cf8c8684e2 100644 --- a/crates/collab/src/db.rs +++ b/crates/collab/src/db.rs @@ -13,6 +13,7 @@ use anyhow::anyhow; use collections::{BTreeMap, HashMap, HashSet}; use dashmap::DashMap; use futures::StreamExt; +use queries::channels::ChannelGraph; use rand::{prelude::StdRng, Rng, SeedableRng}; use rpc::{ proto::{self}, @@ -20,7 +21,7 @@ use rpc::{ }; use sea_orm::{ entity::prelude::*, - sea_query::{Alias, Expr, OnConflict, Query}, + sea_query::{Alias, Expr, OnConflict}, ActiveValue, Condition, ConnectionTrait, DatabaseConnection, DatabaseTransaction, DbErr, FromQueryResult, IntoActiveModel, IsolationLevel, JoinType, QueryOrder, QuerySelect, Statement, TransactionTrait, @@ -47,14 +48,14 @@ pub use ids::*; pub use sea_orm::ConnectOptions; pub use tables::user::Model as User; -use self::queries::channels::ChannelGraph; - pub struct Database { options: ConnectOptions, pool: DatabaseConnection, rooms: DashMap>>, rng: Mutex, executor: Executor, + notification_kinds_by_id: HashMap, + notification_kinds_by_name: HashMap, #[cfg(test)] runtime: Option, } @@ -69,6 +70,8 @@ impl Database { pool: sea_orm::Database::connect(options).await?, rooms: DashMap::with_capacity(16384), rng: Mutex::new(StdRng::seed_from_u64(0)), + notification_kinds_by_id: HashMap::default(), + notification_kinds_by_name: HashMap::default(), executor, #[cfg(test)] runtime: None, @@ -121,6 +124,11 @@ impl Database { Ok(new_migrations) } + pub async fn initialize_static_data(&mut self) -> Result<()> { + self.initialize_notification_kinds().await?; + Ok(()) + } + pub async fn transaction(&self, f: F) -> Result where F: Send + Fn(TransactionHandle) -> Fut, @@ -361,18 +369,9 @@ impl RoomGuard { #[derive(Clone, Debug, PartialEq, Eq)] pub enum Contact { - Accepted { - user_id: UserId, - should_notify: bool, - busy: bool, - }, - Outgoing { - user_id: UserId, - }, - Incoming { - user_id: UserId, - should_notify: bool, - }, + Accepted { user_id: UserId, busy: bool }, + Outgoing { user_id: UserId }, + Incoming { user_id: UserId }, } impl Contact { @@ -385,6 +384,15 @@ impl Contact { } } +pub type NotificationBatch = Vec<(UserId, proto::Notification)>; + +pub struct CreatedChannelMessage { + pub message_id: MessageId, + pub participant_connection_ids: Vec, + pub channel_members: Vec, + pub notifications: NotificationBatch, +} + #[derive(Clone, Debug, PartialEq, Eq, FromQueryResult, Serialize, Deserialize)] pub struct Invite { pub email_address: String, @@ -417,7 +425,6 @@ pub struct WaitlistSummary { pub struct NewUserParams { pub github_login: String, pub github_user_id: i32, - pub invite_count: i32, } #[derive(Debug)] diff --git a/crates/collab/src/db/ids.rs b/crates/collab/src/db/ids.rs index f0de4c255edc7b13eb27656ceaccefb3c5e26c02..433444de67e773a8114109326832e02722f3fe5e 100644 --- a/crates/collab/src/db/ids.rs +++ b/crates/collab/src/db/ids.rs @@ -81,6 +81,8 @@ id_type!(SignupId); id_type!(UserId); id_type!(ChannelBufferCollaboratorId); id_type!(FlagId); +id_type!(NotificationId); +id_type!(NotificationKindId); #[derive(Eq, PartialEq, Copy, Clone, Debug, EnumIter, DeriveActiveEnum, Default)] #[sea_orm(rs_type = "String", db_type = "String(None)")] diff --git a/crates/collab/src/db/queries.rs b/crates/collab/src/db/queries.rs index 80bd8704b27704361241a93c56f5945ef51ef3cc..629e26f1a9e2ac1479f80984d2f9ae3efe7e9ab7 100644 --- a/crates/collab/src/db/queries.rs +++ b/crates/collab/src/db/queries.rs @@ -5,6 +5,7 @@ pub mod buffers; pub mod channels; pub mod contacts; pub mod messages; +pub mod notifications; pub mod projects; pub mod rooms; pub mod servers; diff --git a/crates/collab/src/db/queries/access_tokens.rs b/crates/collab/src/db/queries/access_tokens.rs index def9428a2bedc0a8635364bc48aeb6fe419a2f11..589b6483dfceb5df285ac67b03edbee493e4705b 100644 --- a/crates/collab/src/db/queries/access_tokens.rs +++ b/crates/collab/src/db/queries/access_tokens.rs @@ -1,4 +1,5 @@ use super::*; +use sea_orm::sea_query::Query; impl Database { pub async fn create_access_token( diff --git a/crates/collab/src/db/queries/channels.rs b/crates/collab/src/db/queries/channels.rs index ee989b2ea068ff362adc635e7616b1c25a2de573..4ee7625afdf950a04ffad74516d384a3ee299bd3 100644 --- a/crates/collab/src/db/queries/channels.rs +++ b/crates/collab/src/db/queries/channels.rs @@ -269,13 +269,18 @@ impl Database { &self, channel_id: ChannelId, invitee_id: UserId, - admin_id: UserId, + inviter_id: UserId, role: ChannelRole, - ) -> Result<()> { + ) -> Result { self.transaction(move |tx| async move { - self.check_user_is_channel_admin(channel_id, admin_id, &*tx) + self.check_user_is_channel_admin(channel_id, inviter_id, &*tx) .await?; + let channel = channel::Entity::find_by_id(channel_id) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("no such channel"))?; + channel_member::ActiveModel { id: ActiveValue::NotSet, channel_id: ActiveValue::Set(channel_id), @@ -286,7 +291,20 @@ impl Database { .insert(&*tx) .await?; - Ok(()) + Ok(self + .create_notification( + invitee_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: channel.name, + inviter_id: inviter_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect()) }) .await } @@ -333,7 +351,7 @@ impl Database { channel_id: ChannelId, user_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(move |tx| async move { let rows_affected = if accept { channel_member::Entity::update_many() @@ -351,21 +369,36 @@ impl Database { .await? .rows_affected } else { - channel_member::ActiveModel { - channel_id: ActiveValue::Unchanged(channel_id), - user_id: ActiveValue::Unchanged(user_id), - ..Default::default() - } - .delete(&*tx) - .await? - .rows_affected + channel_member::Entity::delete_many() + .filter( + channel_member::Column::ChannelId + .eq(channel_id) + .and(channel_member::Column::UserId.eq(user_id)) + .and(channel_member::Column::Accepted.eq(false)), + ) + .exec(&*tx) + .await? + .rows_affected }; if rows_affected == 0 { Err(anyhow!("no such invitation"))?; } - Ok(()) + Ok(self + .mark_notification_as_read_with_response( + user_id, + &rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + accept, + &*tx, + ) + .await? + .into_iter() + .collect()) }) .await } @@ -375,7 +408,7 @@ impl Database { channel_id: ChannelId, member_id: UserId, admin_id: UserId, - ) -> Result<()> { + ) -> Result> { self.transaction(|tx| async move { self.check_user_is_channel_admin(channel_id, admin_id, &*tx) .await?; @@ -393,7 +426,17 @@ impl Database { Err(anyhow!("no such member"))?; } - Ok(()) + Ok(self + .remove_notification( + member_id, + rpc::Notification::ChannelInvitation { + channel_id: channel_id.to_proto(), + channel_name: Default::default(), + inviter_id: Default::default(), + }, + &*tx, + ) + .await?) }) .await } @@ -667,10 +710,11 @@ impl Database { pub async fn get_channel_participant_details( &self, channel_id: ChannelId, - admin_id: UserId, + user_id: UserId, ) -> Result> { self.transaction(|tx| async move { - self.check_user_is_channel_admin(channel_id, admin_id, &*tx) + let user_role = self + .check_user_is_channel_member(channel_id, user_id, &*tx) .await?; let channel_visibility = channel::Entity::find() @@ -753,10 +797,26 @@ impl Database { Ok(user_details .into_iter() - .map(|(user_id, details)| proto::ChannelMember { - user_id: user_id.to_proto(), - kind: details.kind.into(), - role: details.channel_role.into(), + .filter_map(|(user_id, mut details)| { + // If the user is not an admin, don't give them as much + // information about the other members. + if user_role != ChannelRole::Admin { + if details.kind == Kind::Invitee + || details.channel_role == ChannelRole::Banned + { + return None; + } + + if details.channel_role == ChannelRole::Admin { + details.channel_role = ChannelRole::Member; + } + } + + Some(proto::ChannelMember { + user_id: user_id.to_proto(), + kind: details.kind.into(), + role: details.channel_role.into(), + }) }) .collect()) }) @@ -806,9 +866,10 @@ impl Database { channel_id: ChannelId, user_id: UserId, tx: &DatabaseTransaction, - ) -> Result<()> { - match self.channel_role_for_user(channel_id, user_id, tx).await? { - Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(()), + ) -> Result { + let channel_role = self.channel_role_for_user(channel_id, user_id, tx).await?; + match channel_role { + Some(ChannelRole::Admin) | Some(ChannelRole::Member) => Ok(channel_role.unwrap()), Some(ChannelRole::Banned) | Some(ChannelRole::Guest) | None => Err(anyhow!( "user is not a channel member or channel does not exist" ))?, diff --git a/crates/collab/src/db/queries/contacts.rs b/crates/collab/src/db/queries/contacts.rs index 2171f1a6bf87354b8017fd35b47d77bba1e25af0..f31f1addbd2b4a210abfa9810f062585a7b656e4 100644 --- a/crates/collab/src/db/queries/contacts.rs +++ b/crates/collab/src/db/queries/contacts.rs @@ -8,7 +8,6 @@ impl Database { user_id_b: UserId, a_to_b: bool, accepted: bool, - should_notify: bool, user_a_busy: bool, user_b_busy: bool, } @@ -53,7 +52,6 @@ impl Database { if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify && db_contact.a_to_b, busy: db_contact.user_b_busy, }); } else if db_contact.a_to_b { @@ -63,19 +61,16 @@ impl Database { } else { contacts.push(Contact::Incoming { user_id: db_contact.user_id_b, - should_notify: db_contact.should_notify, }); } } else if db_contact.accepted { contacts.push(Contact::Accepted { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify && !db_contact.a_to_b, busy: db_contact.user_a_busy, }); } else if db_contact.a_to_b { contacts.push(Contact::Incoming { user_id: db_contact.user_id_a, - should_notify: db_contact.should_notify, }); } else { contacts.push(Contact::Outgoing { @@ -124,7 +119,11 @@ impl Database { .await } - pub async fn send_contact_request(&self, sender_id: UserId, receiver_id: UserId) -> Result<()> { + pub async fn send_contact_request( + &self, + sender_id: UserId, + receiver_id: UserId, + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if sender_id < receiver_id { (sender_id, receiver_id, true) @@ -161,11 +160,22 @@ impl Database { .exec_without_returning(&*tx) .await?; - if rows_affected == 1 { - Ok(()) - } else { - Err(anyhow!("contact already requested"))? + if rows_affected == 0 { + Err(anyhow!("contact already requested"))?; } + + Ok(self + .create_notification( + receiver_id, + rpc::Notification::ContactRequest { + sender_id: sender_id.to_proto(), + }, + true, + &*tx, + ) + .await? + .into_iter() + .collect()) }) .await } @@ -179,7 +189,11 @@ impl Database { /// /// * `requester_id` - The user that initiates this request /// * `responder_id` - The user that will be removed - pub async fn remove_contact(&self, requester_id: UserId, responder_id: UserId) -> Result { + pub async fn remove_contact( + &self, + requester_id: UserId, + responder_id: UserId, + ) -> Result<(bool, Option)> { self.transaction(|tx| async move { let (id_a, id_b) = if responder_id < requester_id { (responder_id, requester_id) @@ -198,7 +212,21 @@ impl Database { .ok_or_else(|| anyhow!("no such contact"))?; contact::Entity::delete_by_id(contact.id).exec(&*tx).await?; - Ok(contact.accepted) + + let mut deleted_notification_id = None; + if !contact.accepted { + deleted_notification_id = self + .remove_notification( + responder_id, + rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + &*tx, + ) + .await?; + } + + Ok((contact.accepted, deleted_notification_id)) }) .await } @@ -249,7 +277,7 @@ impl Database { responder_id: UserId, requester_id: UserId, accept: bool, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { let (id_a, id_b, a_to_b) = if responder_id < requester_id { (responder_id, requester_id, false) @@ -287,11 +315,38 @@ impl Database { result.rows_affected }; - if rows_affected == 1 { - Ok(()) - } else { + if rows_affected == 0 { Err(anyhow!("no such contact request"))? } + + let mut notifications = Vec::new(); + notifications.extend( + self.mark_notification_as_read_with_response( + responder_id, + &rpc::Notification::ContactRequest { + sender_id: requester_id.to_proto(), + }, + accept, + &*tx, + ) + .await?, + ); + + if accept { + notifications.extend( + self.create_notification( + requester_id, + rpc::Notification::ContactRequestAccepted { + responder_id: responder_id.to_proto(), + }, + true, + &*tx, + ) + .await?, + ); + } + + Ok(notifications) }) .await } diff --git a/crates/collab/src/db/queries/messages.rs b/crates/collab/src/db/queries/messages.rs index 06e954103d18c0916b2dc76ba474a06283a93cf9..d406cbb091389cdcc6073873752ffa4e14000fad 100644 --- a/crates/collab/src/db/queries/messages.rs +++ b/crates/collab/src/db/queries/messages.rs @@ -1,4 +1,7 @@ use super::*; +use futures::Stream; +use rpc::Notification; +use sea_orm::TryInsertResult; use time::OffsetDateTime; impl Database { @@ -87,43 +90,118 @@ impl Database { condition = condition.add(channel_message::Column::Id.lt(before_message_id)); } - let mut rows = channel_message::Entity::find() + let rows = channel_message::Entity::find() .filter(condition) .order_by_desc(channel_message::Column::Id) .limit(count as u64) .stream(&*tx) .await?; - let mut messages = Vec::new(); - while let Some(row) = rows.next().await { - let row = row?; - let nonce = row.nonce.as_u64_pair(); - messages.push(proto::ChannelMessage { - id: row.id.to_proto(), - sender_id: row.sender_id.to_proto(), - body: row.body, - timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, - nonce: Some(proto::Nonce { - upper_half: nonce.0, - lower_half: nonce.1, + self.load_channel_messages(rows, &*tx).await + }) + .await + } + + pub async fn get_channel_messages_by_id( + &self, + user_id: UserId, + message_ids: &[MessageId], + ) -> Result> { + self.transaction(|tx| async move { + let rows = channel_message::Entity::find() + .filter(channel_message::Column::Id.is_in(message_ids.iter().copied())) + .order_by_desc(channel_message::Column::Id) + .stream(&*tx) + .await?; + + let mut channel_ids = HashSet::::default(); + let messages = self + .load_channel_messages( + rows.map(|row| { + row.map(|row| { + channel_ids.insert(row.channel_id); + row + }) }), - }); + &*tx, + ) + .await?; + + for channel_id in channel_ids { + self.check_user_is_channel_member(channel_id, user_id, &*tx) + .await?; } - drop(rows); - messages.reverse(); + Ok(messages) }) .await } + async fn load_channel_messages( + &self, + mut rows: impl Send + Unpin + Stream>, + tx: &DatabaseTransaction, + ) -> Result> { + let mut messages = Vec::new(); + while let Some(row) = rows.next().await { + let row = row?; + let nonce = row.nonce.as_u64_pair(); + messages.push(proto::ChannelMessage { + id: row.id.to_proto(), + sender_id: row.sender_id.to_proto(), + body: row.body, + timestamp: row.sent_at.assume_utc().unix_timestamp() as u64, + mentions: vec![], + nonce: Some(proto::Nonce { + upper_half: nonce.0, + lower_half: nonce.1, + }), + }); + } + drop(rows); + messages.reverse(); + + let mut mentions = channel_message_mention::Entity::find() + .filter(channel_message_mention::Column::MessageId.is_in(messages.iter().map(|m| m.id))) + .order_by_asc(channel_message_mention::Column::MessageId) + .order_by_asc(channel_message_mention::Column::StartOffset) + .stream(&*tx) + .await?; + + let mut message_ix = 0; + while let Some(mention) = mentions.next().await { + let mention = mention?; + let message_id = mention.message_id.to_proto(); + while let Some(message) = messages.get_mut(message_ix) { + if message.id < message_id { + message_ix += 1; + } else { + if message.id == message_id { + message.mentions.push(proto::ChatMention { + range: Some(proto::Range { + start: mention.start_offset as u64, + end: mention.end_offset as u64, + }), + user_id: mention.user_id.to_proto(), + }); + } + break; + } + } + } + + Ok(messages) + } + pub async fn create_channel_message( &self, channel_id: ChannelId, user_id: UserId, body: &str, + mentions: &[proto::ChatMention], timestamp: OffsetDateTime, nonce: u128, - ) -> Result<(MessageId, Vec, Vec)> { + ) -> Result { self.transaction(|tx| async move { let mut rows = channel_chat_participant::Entity::find() .filter(channel_chat_participant::Column::ChannelId.eq(channel_id)) @@ -150,7 +228,7 @@ impl Database { let timestamp = timestamp.to_offset(time::UtcOffset::UTC); let timestamp = time::PrimitiveDateTime::new(timestamp.date(), timestamp.time()); - let message = channel_message::Entity::insert(channel_message::ActiveModel { + let result = channel_message::Entity::insert(channel_message::ActiveModel { channel_id: ActiveValue::Set(channel_id), sender_id: ActiveValue::Set(user_id), body: ActiveValue::Set(body.to_string()), @@ -159,37 +237,87 @@ impl Database { id: ActiveValue::NotSet, }) .on_conflict( - OnConflict::column(channel_message::Column::Nonce) - .update_column(channel_message::Column::Nonce) - .to_owned(), + OnConflict::columns([ + channel_message::Column::SenderId, + channel_message::Column::Nonce, + ]) + .do_nothing() + .to_owned(), ) + .do_nothing() .exec(&*tx) .await?; - #[derive(Debug, Clone, Copy, EnumIter, DeriveColumn)] - enum QueryConnectionId { - ConnectionId, - } + let message_id; + let mut notifications = Vec::new(); + match result { + TryInsertResult::Inserted(result) => { + message_id = result.last_insert_id; + let mentioned_user_ids = + mentions.iter().map(|m| m.user_id).collect::>(); + let mentions = mentions + .iter() + .filter_map(|mention| { + let range = mention.range.as_ref()?; + if !body.is_char_boundary(range.start as usize) + || !body.is_char_boundary(range.end as usize) + { + return None; + } + Some(channel_message_mention::ActiveModel { + message_id: ActiveValue::Set(message_id), + start_offset: ActiveValue::Set(range.start as i32), + end_offset: ActiveValue::Set(range.end as i32), + user_id: ActiveValue::Set(UserId::from_proto(mention.user_id)), + }) + }) + .collect::>(); + if !mentions.is_empty() { + channel_message_mention::Entity::insert_many(mentions) + .exec(&*tx) + .await?; + } - // Observe this message for the sender - self.observe_channel_message_internal( - channel_id, - user_id, - message.last_insert_id, - &*tx, - ) - .await?; + for mentioned_user in mentioned_user_ids { + notifications.extend( + self.create_notification( + UserId::from_proto(mentioned_user), + rpc::Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: user_id.to_proto(), + channel_id: channel_id.to_proto(), + }, + false, + &*tx, + ) + .await?, + ); + } + + self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) + .await?; + } + _ => { + message_id = channel_message::Entity::find() + .filter(channel_message::Column::Nonce.eq(Uuid::from_u128(nonce))) + .one(&*tx) + .await? + .ok_or_else(|| anyhow!("failed to insert message"))? + .id; + } + } let mut channel_members = self .get_channel_participants_internal(channel_id, &*tx) .await?; channel_members.retain(|member| !participant_user_ids.contains(member)); - Ok(( - message.last_insert_id, + Ok(CreatedChannelMessage { + message_id, participant_connection_ids, channel_members, - )) + notifications, + }) }) .await } @@ -199,11 +327,24 @@ impl Database { channel_id: ChannelId, user_id: UserId, message_id: MessageId, - ) -> Result<()> { + ) -> Result { self.transaction(|tx| async move { self.observe_channel_message_internal(channel_id, user_id, message_id, &*tx) .await?; - Ok(()) + let mut batch = NotificationBatch::default(); + batch.extend( + self.mark_notification_as_read( + user_id, + &Notification::ChannelMessageMention { + message_id: message_id.to_proto(), + sender_id: Default::default(), + channel_id: Default::default(), + }, + &*tx, + ) + .await?, + ); + Ok(batch) }) .await } diff --git a/crates/collab/src/db/queries/notifications.rs b/crates/collab/src/db/queries/notifications.rs new file mode 100644 index 0000000000000000000000000000000000000000..6f2511c23e7cd383760aa29ec62a65ca30c636d8 --- /dev/null +++ b/crates/collab/src/db/queries/notifications.rs @@ -0,0 +1,262 @@ +use super::*; +use rpc::Notification; + +impl Database { + pub async fn initialize_notification_kinds(&mut self) -> Result<()> { + notification_kind::Entity::insert_many(Notification::all_variant_names().iter().map( + |kind| notification_kind::ActiveModel { + name: ActiveValue::Set(kind.to_string()), + ..Default::default() + }, + )) + .on_conflict(OnConflict::new().do_nothing().to_owned()) + .exec_without_returning(&self.pool) + .await?; + + let mut rows = notification_kind::Entity::find().stream(&self.pool).await?; + while let Some(row) = rows.next().await { + let row = row?; + self.notification_kinds_by_name.insert(row.name, row.id); + } + + for name in Notification::all_variant_names() { + if let Some(id) = self.notification_kinds_by_name.get(*name).copied() { + self.notification_kinds_by_id.insert(id, name); + } + } + + Ok(()) + } + + pub async fn get_notifications( + &self, + recipient_id: UserId, + limit: usize, + before_id: Option, + ) -> Result> { + self.transaction(|tx| async move { + let mut result = Vec::new(); + let mut condition = + Condition::all().add(notification::Column::RecipientId.eq(recipient_id)); + + if let Some(before_id) = before_id { + condition = condition.add(notification::Column::Id.lt(before_id)); + } + + let mut rows = notification::Entity::find() + .filter(condition) + .order_by_desc(notification::Column::Id) + .limit(limit as u64) + .stream(&*tx) + .await?; + while let Some(row) = rows.next().await { + let row = row?; + let kind = row.kind; + if let Some(proto) = model_to_proto(self, row) { + result.push(proto); + } else { + log::warn!("unknown notification kind {:?}", kind); + } + } + result.reverse(); + Ok(result) + }) + .await + } + + /// Create a notification. If `avoid_duplicates` is set to true, then avoid + /// creating a new notification if the given recipient already has an + /// unread notification with the given kind and entity id. + pub async fn create_notification( + &self, + recipient_id: UserId, + notification: Notification, + avoid_duplicates: bool, + tx: &DatabaseTransaction, + ) -> Result> { + if avoid_duplicates { + if self + .find_notification(recipient_id, ¬ification, tx) + .await? + .is_some() + { + return Ok(None); + } + } + + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + let model = notification::ActiveModel { + recipient_id: ActiveValue::Set(recipient_id), + kind: ActiveValue::Set(kind), + entity_id: ActiveValue::Set(proto.entity_id.map(|id| id as i32)), + content: ActiveValue::Set(proto.content.clone()), + ..Default::default() + } + .save(&*tx) + .await?; + + Ok(Some(( + recipient_id, + proto::Notification { + id: model.id.as_ref().to_proto(), + kind: proto.kind, + timestamp: model.created_at.as_ref().assume_utc().unix_timestamp() as u64, + is_read: false, + response: None, + content: proto.content, + entity_id: proto.entity_id, + }, + ))) + } + + /// Remove an unread notification with the given recipient, kind and + /// entity id. + pub async fn remove_notification( + &self, + recipient_id: UserId, + notification: Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let id = self + .find_notification(recipient_id, ¬ification, tx) + .await?; + if let Some(id) = id { + notification::Entity::delete_by_id(id).exec(tx).await?; + } + Ok(id) + } + + /// Populate the response for the notification with the given kind and + /// entity id. + pub async fn mark_notification_as_read_with_response( + &self, + recipient_id: UserId, + notification: &Notification, + response: bool, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, Some(response), tx) + .await + } + + pub async fn mark_notification_as_read( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + self.mark_notification_as_read_internal(recipient_id, notification, None, tx) + .await + } + + pub async fn mark_notification_as_read_by_id( + &self, + recipient_id: UserId, + notification_id: NotificationId, + ) -> Result { + self.transaction(|tx| async move { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(notification_id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + ..Default::default() + }) + .exec(&*tx) + .await?; + Ok(model_to_proto(self, row) + .map(|notification| (recipient_id, notification)) + .into_iter() + .collect()) + }) + .await + } + + async fn mark_notification_as_read_internal( + &self, + recipient_id: UserId, + notification: &Notification, + response: Option, + tx: &DatabaseTransaction, + ) -> Result> { + if let Some(id) = self + .find_notification(recipient_id, notification, &*tx) + .await? + { + let row = notification::Entity::update(notification::ActiveModel { + id: ActiveValue::Unchanged(id), + recipient_id: ActiveValue::Unchanged(recipient_id), + is_read: ActiveValue::Set(true), + response: if let Some(response) = response { + ActiveValue::Set(Some(response)) + } else { + ActiveValue::NotSet + }, + ..Default::default() + }) + .exec(tx) + .await?; + Ok(model_to_proto(self, row).map(|notification| (recipient_id, notification))) + } else { + Ok(None) + } + } + + /// Find an unread notification by its recipient, kind and entity id. + async fn find_notification( + &self, + recipient_id: UserId, + notification: &Notification, + tx: &DatabaseTransaction, + ) -> Result> { + let proto = notification.to_proto(); + let kind = notification_kind_from_proto(self, &proto)?; + + #[derive(Copy, Clone, Debug, EnumIter, DeriveColumn)] + enum QueryIds { + Id, + } + + Ok(notification::Entity::find() + .select_only() + .column(notification::Column::Id) + .filter( + Condition::all() + .add(notification::Column::RecipientId.eq(recipient_id)) + .add(notification::Column::IsRead.eq(false)) + .add(notification::Column::Kind.eq(kind)) + .add(if proto.entity_id.is_some() { + notification::Column::EntityId.eq(proto.entity_id) + } else { + notification::Column::EntityId.is_null() + }), + ) + .into_values::<_, QueryIds>() + .one(&*tx) + .await?) + } +} + +fn model_to_proto(this: &Database, row: notification::Model) -> Option { + let kind = this.notification_kinds_by_id.get(&row.kind)?; + Some(proto::Notification { + id: row.id.to_proto(), + kind: kind.to_string(), + timestamp: row.created_at.assume_utc().unix_timestamp() as u64, + is_read: row.is_read, + response: row.response, + content: row.content, + entity_id: row.entity_id.map(|id| id as u64), + }) +} + +fn notification_kind_from_proto( + this: &Database, + proto: &proto::Notification, +) -> Result { + Ok(this + .notification_kinds_by_name + .get(&proto.kind) + .copied() + .ok_or_else(|| anyhow!("invalid notification kind {:?}", proto.kind))?) +} diff --git a/crates/collab/src/db/tables.rs b/crates/collab/src/db/tables.rs index e19391da7dd513970b0fa593d7977fa7689c0510..0acb266d9dab95e0205fb4cd2bb522a7167e8b86 100644 --- a/crates/collab/src/db/tables.rs +++ b/crates/collab/src/db/tables.rs @@ -7,11 +7,14 @@ pub mod channel_buffer_collaborator; pub mod channel_chat_participant; pub mod channel_member; pub mod channel_message; +pub mod channel_message_mention; pub mod channel_path; pub mod contact; pub mod feature_flag; pub mod follower; pub mod language_server; +pub mod notification; +pub mod notification_kind; pub mod observed_buffer_edits; pub mod observed_channel_messages; pub mod project; diff --git a/crates/collab/src/db/tables/channel_message_mention.rs b/crates/collab/src/db/tables/channel_message_mention.rs new file mode 100644 index 0000000000000000000000000000000000000000..6155b057f0cf8862cb26f6efff30669d59592eb8 --- /dev/null +++ b/crates/collab/src/db/tables/channel_message_mention.rs @@ -0,0 +1,43 @@ +use crate::db::{MessageId, UserId}; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "channel_message_mentions")] +pub struct Model { + #[sea_orm(primary_key)] + pub message_id: MessageId, + #[sea_orm(primary_key)] + pub start_offset: i32, + pub end_offset: i32, + pub user_id: UserId, +} + +impl ActiveModelBehavior for ActiveModel {} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::channel_message::Entity", + from = "Column::MessageId", + to = "super::channel_message::Column::Id" + )] + Message, + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::UserId", + to = "super::user::Column::Id" + )] + MentionedUser, +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::Message.def() + } +} + +impl Related for Entity { + fn to() -> RelationDef { + Relation::MentionedUser.def() + } +} diff --git a/crates/collab/src/db/tables/notification.rs b/crates/collab/src/db/tables/notification.rs new file mode 100644 index 0000000000000000000000000000000000000000..3105198fa21764351b4a2343258e91055b6a8641 --- /dev/null +++ b/crates/collab/src/db/tables/notification.rs @@ -0,0 +1,29 @@ +use crate::db::{NotificationId, NotificationKindId, UserId}; +use sea_orm::entity::prelude::*; +use time::PrimitiveDateTime; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notifications")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationId, + pub created_at: PrimitiveDateTime, + pub recipient_id: UserId, + pub kind: NotificationKindId, + pub entity_id: Option, + pub content: String, + pub is_read: bool, + pub response: Option, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation { + #[sea_orm( + belongs_to = "super::user::Entity", + from = "Column::RecipientId", + to = "super::user::Column::Id" + )] + Recipient, +} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tables/notification_kind.rs b/crates/collab/src/db/tables/notification_kind.rs new file mode 100644 index 0000000000000000000000000000000000000000..865b5da04bad2a7068aa6a2fd8e8adbe7586fd08 --- /dev/null +++ b/crates/collab/src/db/tables/notification_kind.rs @@ -0,0 +1,15 @@ +use crate::db::NotificationKindId; +use sea_orm::entity::prelude::*; + +#[derive(Clone, Debug, PartialEq, Eq, DeriveEntityModel)] +#[sea_orm(table_name = "notification_kinds")] +pub struct Model { + #[sea_orm(primary_key)] + pub id: NotificationKindId, + pub name: String, +} + +#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)] +pub enum Relation {} + +impl ActiveModelBehavior for ActiveModel {} diff --git a/crates/collab/src/db/tests.rs b/crates/collab/src/db/tests.rs index 99a605106eafab55f50b2b37a0a5a027b1cb9457..83154b9a0dfbbdb63b13e6791bb8a150fadc434a 100644 --- a/crates/collab/src/db/tests.rs +++ b/crates/collab/src/db/tests.rs @@ -10,7 +10,10 @@ use parking_lot::Mutex; use rpc::proto::ChannelEdge; use sea_orm::ConnectionTrait; use sqlx::migrate::MigrateDatabase; -use std::sync::Arc; +use std::sync::{ + atomic::{AtomicI32, Ordering::SeqCst}, + Arc, +}; const TEST_RELEASE_CHANNEL: &'static str = "test"; @@ -31,7 +34,7 @@ impl TestDb { let mut db = runtime.block_on(async { let mut options = ConnectOptions::new(url); options.max_connections(5); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let sql = include_str!(concat!( @@ -45,6 +48,7 @@ impl TestDb { )) .await .unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -79,11 +83,12 @@ impl TestDb { options .max_connections(5) .idle_timeout(Duration::from_secs(0)); - let db = Database::new(options, Executor::Deterministic(background)) + let mut db = Database::new(options, Executor::Deterministic(background)) .await .unwrap(); let migrations_path = concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"); db.migrate(Path::new(migrations_path), false).await.unwrap(); + db.initialize_notification_kinds().await.unwrap(); db }); @@ -172,3 +177,19 @@ fn graph(channels: &[(ChannelId, &'static str)], edges: &[(ChannelId, ChannelId) graph } + +static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); + +async fn new_test_user(db: &Arc, email: &str) -> UserId { + db.create_user( + email, + false, + NewUserParams { + github_login: email[0..email.find("@").unwrap()].to_string(), + github_user_id: GITHUB_USER_ID.fetch_add(1, SeqCst), + }, + ) + .await + .unwrap() + .user_id +} diff --git a/crates/collab/src/db/tests/buffer_tests.rs b/crates/collab/src/db/tests/buffer_tests.rs index 51ba9bf655221a5c611ad3fa023631f46151f144..222514da0b32d9b8bd551bcfa60901b67bb590a0 100644 --- a/crates/collab/src/db/tests/buffer_tests.rs +++ b/crates/collab/src/db/tests/buffer_tests.rs @@ -17,7 +17,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -30,7 +29,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -45,7 +43,6 @@ async fn test_channel_buffers(db: &Arc) { NewUserParams { github_login: "user_c".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -178,7 +175,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_a".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -191,7 +187,6 @@ async fn test_channel_buffers_last_operations(db: &Database) { NewUserParams { github_login: "user_b".into(), github_user_id: 102, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/channel_tests.rs b/crates/collab/src/db/tests/channel_tests.rs index 40842aff5c3bbd90fa7fc1e38297096e6c2e8bcc..556437e45b6e5471b93619e0d65f1d4430d8dedf 100644 --- a/crates/collab/src/db/tests/channel_tests.rs +++ b/crates/collab/src/db/tests/channel_tests.rs @@ -1,21 +1,17 @@ -use collections::{HashMap, HashSet}; -use rpc::{ - proto::{self}, - ConnectionId, -}; - use crate::{ db::{ queries::channels::ChannelGraph, - tests::{graph, TEST_RELEASE_CHANNEL}, - ChannelId, ChannelRole, Database, NewUserParams, RoomId, UserId, + tests::{graph, new_test_user, TEST_RELEASE_CHANNEL}, + ChannelId, ChannelRole, Database, NewUserParams, RoomId, }, test_both_dbs, }; -use std::sync::{ - atomic::{AtomicI32, Ordering}, - Arc, +use collections::{HashMap, HashSet}; +use rpc::{ + proto::{self}, + ConnectionId, }; +use std::sync::Arc; test_both_dbs!(test_channels, test_channels_postgres, test_channels_sqlite); @@ -27,7 +23,6 @@ async fn test_channels(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -41,7 +36,6 @@ async fn test_channels(db: &Arc) { NewUserParams { github_login: "user2".into(), github_user_id: 6, - invite_count: 0, }, ) .await @@ -186,7 +180,6 @@ async fn test_joining_channels(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -199,7 +192,6 @@ async fn test_joining_channels(db: &Arc) { NewUserParams { github_login: "user2".into(), github_user_id: 6, - invite_count: 0, }, ) .await @@ -354,7 +346,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -368,7 +359,6 @@ async fn test_channel_renames(db: &Arc) { NewUserParams { github_login: "user2".into(), github_user_id: 6, - invite_count: 0, }, ) .await @@ -409,7 +399,6 @@ async fn test_db_channel_moving(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -767,7 +756,6 @@ async fn test_db_channel_moving_bugs(db: &Arc) { NewUserParams { github_login: "user1".into(), github_user_id: 5, - invite_count: 0, }, ) .await @@ -1113,20 +1101,3 @@ fn assert_dag(actual: ChannelGraph, expected: &[(ChannelId, Option)]) pretty_assertions::assert_eq!(actual_map, expected_map) } - -static GITHUB_USER_ID: AtomicI32 = AtomicI32::new(5); - -async fn new_test_user(db: &Arc, email: &str) -> UserId { - db.create_user( - email, - false, - NewUserParams { - github_login: email[0..email.find("@").unwrap()].to_string(), - github_user_id: GITHUB_USER_ID.fetch_add(1, Ordering::SeqCst), - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id -} diff --git a/crates/collab/src/db/tests/db_tests.rs b/crates/collab/src/db/tests/db_tests.rs index 1520e081c07ead1afc376f84d2e12918fef40db2..c4b82f8cecfbf8ba0dff28e7749e09048a9a3539 100644 --- a/crates/collab/src/db/tests/db_tests.rs +++ b/crates/collab/src/db/tests/db_tests.rs @@ -22,7 +22,6 @@ async fn test_get_users(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -88,7 +87,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login1".into(), github_user_id: 101, - invite_count: 0, }, ) .await @@ -101,7 +99,6 @@ async fn test_get_or_create_user_by_github_account(db: &Arc) { NewUserParams { github_login: "login2".into(), github_user_id: 102, - invite_count: 0, }, ) .await @@ -156,7 +153,6 @@ async fn test_create_access_tokens(db: &Arc) { NewUserParams { github_login: "u1".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -238,7 +234,6 @@ async fn test_add_contacts(db: &Arc) { NewUserParams { github_login: format!("user{i}"), github_user_id: i, - invite_count: 0, }, ) .await @@ -264,10 +259,7 @@ async fn test_add_contacts(db: &Arc) { ); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: true - }] + &[Contact::Incoming { user_id: user_1 }] ); // User 2 dismisses the contact request notification without accepting or rejecting. @@ -280,10 +272,7 @@ async fn test_add_contacts(db: &Arc) { .unwrap(); assert_eq!( db.get_contacts(user_2).await.unwrap(), - &[Contact::Incoming { - user_id: user_1, - should_notify: false - }] + &[Contact::Incoming { user_id: user_1 }] ); // User can't accept their own contact request @@ -299,7 +288,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }], ); @@ -309,7 +297,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -326,7 +313,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: true, busy: false, }] ); @@ -339,7 +325,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_1).await.unwrap(), &[Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }] ); @@ -353,12 +338,10 @@ async fn test_add_contacts(db: &Arc) { &[ Contact::Accepted { user_id: user_2, - should_notify: false, busy: false, }, Contact::Accepted { user_id: user_3, - should_notify: false, busy: false, } ] @@ -367,7 +350,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -383,7 +365,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_2).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }] ); @@ -391,7 +372,6 @@ async fn test_add_contacts(db: &Arc) { db.get_contacts(user_3).await.unwrap(), &[Contact::Accepted { user_id: user_1, - should_notify: false, busy: false, }], ); @@ -415,7 +395,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person1".into(), github_user_id: 101, - invite_count: 5, }, ) .await @@ -431,7 +410,6 @@ async fn test_metrics_id(db: &Arc) { NewUserParams { github_login: "person2".into(), github_user_id: 102, - invite_count: 5, }, ) .await @@ -460,7 +438,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -472,7 +449,6 @@ async fn test_project_count(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await @@ -554,7 +530,6 @@ async fn test_fuzzy_search_users() { NewUserParams { github_login: github_login.into(), github_user_id: i as i32, - invite_count: 0, }, ) .await @@ -596,7 +571,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "admin".into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -608,7 +582,6 @@ async fn test_non_matching_release_channels(db: &Arc) { NewUserParams { github_login: "user".into(), github_user_id: 1, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/feature_flag_tests.rs b/crates/collab/src/db/tests/feature_flag_tests.rs index 9d5f039747c18fb6cfae77191654ba5b4584e21e..0286a6308e4208919e564e9526505045f2ee31d1 100644 --- a/crates/collab/src/db/tests/feature_flag_tests.rs +++ b/crates/collab/src/db/tests/feature_flag_tests.rs @@ -18,7 +18,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user1"), github_user_id: 1, - invite_count: 0, }, ) .await @@ -32,7 +31,6 @@ async fn test_get_user_flags(db: &Arc) { NewUserParams { github_login: format!("user2"), github_user_id: 2, - invite_count: 0, }, ) .await diff --git a/crates/collab/src/db/tests/message_tests.rs b/crates/collab/src/db/tests/message_tests.rs index 272d8e01009ac8758e887d2c0ec4f04570464923..97b3142930abc5d5f641825f7844b63e8c79a5dc 100644 --- a/crates/collab/src/db/tests/message_tests.rs +++ b/crates/collab/src/db/tests/message_tests.rs @@ -1,7 +1,9 @@ +use super::new_test_user; use crate::{ - db::{ChannelRole, Database, MessageId, NewUserParams}, + db::{ChannelRole, Database, MessageId}, test_both_dbs, }; +use channel::mentions_to_proto; use std::sync::Arc; use time::OffsetDateTime; @@ -12,19 +14,7 @@ test_both_dbs!( ); async fn test_channel_message_retrieval(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) - .await - .unwrap() - .user_id; + let user = new_test_user(db, "user@example.com").await; let channel = db.create_channel("channel", None, user).await.unwrap(); let owner_id = db.create_server("test").await.unwrap().0 as u32; @@ -35,11 +25,18 @@ async fn test_channel_message_retrieval(db: &Arc) { let mut all_messages = Vec::new(); for i in 0..10 { all_messages.push( - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) - .await - .unwrap() - .0 - .to_proto(), + db.create_channel_message( + channel, + user, + &i.to_string(), + &[], + OffsetDateTime::now_utc(), + i, + ) + .await + .unwrap() + .message_id + .to_proto(), ); } @@ -74,99 +71,154 @@ test_both_dbs!( ); async fn test_channel_message_nonces(db: &Arc) { - let user = db - .create_user( - "user@example.com", - false, - NewUserParams { - github_login: "user".into(), - github_user_id: 1, - invite_count: 0, - }, - ) + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + let channel = db.create_channel("channel", None, user_a).await.unwrap(); + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) .await - .unwrap() - .user_id; - let channel = db.create_channel("channel", None, user).await.unwrap(); - - let owner_id = db.create_server("test").await.unwrap().0 as u32; - - db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user) + .unwrap(); + db.invite_channel_member(channel, user_c, user_a, ChannelRole::Member) .await .unwrap(); - - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + db.respond_to_channel_invite(channel, user_b, true) .await .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) + db.respond_to_channel_invite(channel, user_c, true) .await .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 0 }, user_a) .await .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) + db.join_channel_chat(channel, rpc::ConnectionId { owner_id, id: 1 }, user_b) .await .unwrap(); - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); -} - -test_both_dbs!( - test_channel_message_new_notification, - test_channel_message_new_notification_postgres, - test_channel_message_new_notification_sqlite -); - -async fn test_channel_message_new_notification(db: &Arc) { - let user = db - .create_user( - "user_a@example.com", - false, - NewUserParams { - github_login: "user_a".into(), - github_user_id: 1, - invite_count: 0, - }, + // As user A, create messages that re-use the same nonces. The requests + // succeed, but return the same ids. + let id1 = db + .create_channel_message( + channel, + user_a, + "hi @user_b", + &mentions_to_proto(&[(3..10, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 100, ) .await .unwrap() - .user_id; - let observer = db - .create_user( - "user_b@example.com", - false, - NewUserParams { - github_login: "user_b".into(), - github_user_id: 1, - invite_count: 0, - }, + .message_id; + let id2 = db + .create_channel_message( + channel, + user_a, + "hello, fellow users", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, + ) + .await + .unwrap() + .message_id; + let id3 = db + .create_channel_message( + channel, + user_a, + "bye @user_c (same nonce as first message)", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + let id4 = db + .create_channel_message( + channel, + user_a, + "omg (same nonce as second message)", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 200, ) .await .unwrap() - .user_id; + .message_id; - let channel_1 = db.create_channel("channel", None, user).await.unwrap(); + // As a different user, reuse one of the same nonces. This request succeeds + // and returns a different id. + let id5 = db + .create_channel_message( + channel, + user_b, + "omg @user_a (same nonce as user_a's first message)", + &mentions_to_proto(&[(4..11, user_a.to_proto())]), + OffsetDateTime::now_utc(), + 100, + ) + .await + .unwrap() + .message_id; + + assert_ne!(id1, id2); + assert_eq!(id1, id3); + assert_eq!(id2, id4); + assert_ne!(id5, id1); + let messages = db + .get_channel_messages(channel, user_a, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.id, m.body, m.mentions)) + .collect::>(); + assert_eq!( + messages, + &[ + ( + id1.to_proto(), + "hi @user_b".into(), + mentions_to_proto(&[(3..10, user_b.to_proto())]), + ), + ( + id2.to_proto(), + "hello, fellow users".into(), + mentions_to_proto(&[]) + ), + ( + id5.to_proto(), + "omg @user_a (same nonce as user_a's first message)".into(), + mentions_to_proto(&[(4..11, user_a.to_proto())]), + ), + ] + ); +} + +test_both_dbs!( + test_unseen_channel_messages, + test_unseen_channel_messages_postgres, + test_unseen_channel_messages_sqlite +); + +async fn test_unseen_channel_messages(db: &Arc) { + let user = new_test_user(db, "user_a@example.com").await; + let observer = new_test_user(db, "user_b@example.com").await; + + let channel_1 = db.create_channel("channel", None, user).await.unwrap(); let channel_2 = db.create_channel("channel-2", None, user).await.unwrap(); db.invite_channel_member(channel_1, observer, user, ChannelRole::Member) .await .unwrap(); - - db.respond_to_channel_invite(channel_1, observer, true) + db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) .await .unwrap(); - db.invite_channel_member(channel_2, observer, user, ChannelRole::Member) + db.respond_to_channel_invite(channel_1, observer, true) .await .unwrap(); - db.respond_to_channel_invite(channel_2, observer, true) .await .unwrap(); @@ -179,28 +231,31 @@ async fn test_channel_message_new_notification(db: &Arc) { .unwrap(); let _ = db - .create_channel_message(channel_1, user, "1_1", OffsetDateTime::now_utc(), 1) + .create_channel_message(channel_1, user, "1_1", &[], OffsetDateTime::now_utc(), 1) .await .unwrap(); - let (second_message, _, _) = db - .create_channel_message(channel_1, user, "1_2", OffsetDateTime::now_utc(), 2) + let second_message = db + .create_channel_message(channel_1, user, "1_2", &[], OffsetDateTime::now_utc(), 2) .await - .unwrap(); + .unwrap() + .message_id; - let (third_message, _, _) = db - .create_channel_message(channel_1, user, "1_3", OffsetDateTime::now_utc(), 3) + let third_message = db + .create_channel_message(channel_1, user, "1_3", &[], OffsetDateTime::now_utc(), 3) .await - .unwrap(); + .unwrap() + .message_id; db.join_channel_chat(channel_2, user_connection_id, user) .await .unwrap(); - let (fourth_message, _, _) = db - .create_channel_message(channel_2, user, "2_1", OffsetDateTime::now_utc(), 4) + let fourth_message = db + .create_channel_message(channel_2, user, "2_1", &[], OffsetDateTime::now_utc(), 4) .await - .unwrap(); + .unwrap() + .message_id; // Check that observer has new messages let unseen_messages = db @@ -295,3 +350,96 @@ async fn test_channel_message_new_notification(db: &Arc) { }] ); } + +test_both_dbs!( + test_channel_message_mentions, + test_channel_message_mentions_postgres, + test_channel_message_mentions_sqlite +); + +async fn test_channel_message_mentions(db: &Arc) { + let user_a = new_test_user(db, "user_a@example.com").await; + let user_b = new_test_user(db, "user_b@example.com").await; + let user_c = new_test_user(db, "user_c@example.com").await; + + let channel = db.create_channel("channel", None, user_a).await.unwrap(); + db.invite_channel_member(channel, user_b, user_a, ChannelRole::Member) + .await + .unwrap(); + db.respond_to_channel_invite(channel, user_b, true) + .await + .unwrap(); + + let owner_id = db.create_server("test").await.unwrap().0 as u32; + let connection_id = rpc::ConnectionId { owner_id, id: 0 }; + db.join_channel_chat(channel, connection_id, user_a) + .await + .unwrap(); + + db.create_channel_message( + channel, + user_a, + "hi @user_b and @user_c", + &mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 1, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "bye @user_c", + &mentions_to_proto(&[(4..11, user_c.to_proto())]), + OffsetDateTime::now_utc(), + 2, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "umm", + &mentions_to_proto(&[]), + OffsetDateTime::now_utc(), + 3, + ) + .await + .unwrap(); + db.create_channel_message( + channel, + user_a, + "@user_b, stop.", + &mentions_to_proto(&[(0..7, user_b.to_proto())]), + OffsetDateTime::now_utc(), + 4, + ) + .await + .unwrap(); + + let messages = db + .get_channel_messages(channel, user_b, 5, None) + .await + .unwrap() + .into_iter() + .map(|m| (m.body, m.mentions)) + .collect::>(); + assert_eq!( + &messages, + &[ + ( + "hi @user_b and @user_c".into(), + mentions_to_proto(&[(3..10, user_b.to_proto()), (15..22, user_c.to_proto())]), + ), + ( + "bye @user_c".into(), + mentions_to_proto(&[(4..11, user_c.to_proto())]), + ), + ("umm".into(), mentions_to_proto(&[]),), + ( + "@user_b, stop.".into(), + mentions_to_proto(&[(0..7, user_b.to_proto())]), + ), + ] + ); +} diff --git a/crates/collab/src/lib.rs b/crates/collab/src/lib.rs index 13fb8ed0ebc2c27a8fb63ba3ec5485f74f6e4390..85216525b0018c6d051c55a5882af8445f45c7d0 100644 --- a/crates/collab/src/lib.rs +++ b/crates/collab/src/lib.rs @@ -119,7 +119,9 @@ impl AppState { pub async fn new(config: Config) -> Result> { let mut db_options = db::ConnectOptions::new(config.database_url.clone()); db_options.max_connections(config.database_max_connections); - let db = Database::new(db_options, Executor::Production).await?; + let mut db = Database::new(db_options, Executor::Production).await?; + db.initialize_notification_kinds().await?; + let live_kit_client = if let Some(((server, key), secret)) = config .live_kit_server .as_ref() diff --git a/crates/collab/src/rpc.rs b/crates/collab/src/rpc.rs index 15ea3b24e133912d030ff15a850c5c3aae70bea8..5a29861351daf1ff91ecb382e8a4ad7614c01a66 100644 --- a/crates/collab/src/rpc.rs +++ b/crates/collab/src/rpc.rs @@ -3,8 +3,8 @@ mod connection_pool; use crate::{ auth, db::{ - self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, Database, MessageId, - ProjectId, RoomId, ServerId, User, UserId, + self, BufferId, ChannelId, ChannelVisibility, ChannelsForUser, CreatedChannelMessage, + Database, MessageId, NotificationId, ProjectId, RoomId, ServerId, User, UserId, }, executor::Executor, AppState, Result, @@ -70,6 +70,7 @@ pub const CLEANUP_TIMEOUT: Duration = Duration::from_secs(10); const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; +const NOTIFICATION_COUNT_PER_PAGE: usize = 50; lazy_static! { static ref METRIC_CONNECTIONS: IntGauge = @@ -270,6 +271,9 @@ impl Server { .add_request_handler(send_channel_message) .add_request_handler(remove_channel_message) .add_request_handler(get_channel_messages) + .add_request_handler(get_channel_messages_by_id) + .add_request_handler(get_notifications) + .add_request_handler(mark_notification_as_read) .add_request_handler(link_channel) .add_request_handler(unlink_channel) .add_request_handler(move_channel) @@ -389,7 +393,7 @@ impl Server { let contacts = app_state.db.get_contacts(user_id).await.trace_err(); if let Some((busy, contacts)) = busy.zip(contacts) { let pool = pool.lock(); - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, @@ -583,7 +587,7 @@ impl Server { let (contacts, channels_for_user, channel_invites) = future::try_join3( this.app_state.db.get_contacts(user_id), this.app_state.db.get_channels_for_user(user_id), - this.app_state.db.get_channel_invites_for_user(user_id) + this.app_state.db.get_channel_invites_for_user(user_id), ).await?; { @@ -689,7 +693,7 @@ impl Server { if let Some(user) = self.app_state.db.get_user_by_id(inviter_id).await? { if let Some(code) = &user.invite_code { let pool = self.connection_pool.lock(); - let invitee_contact = contact_for_user(invitee_id, true, false, &pool); + let invitee_contact = contact_for_user(invitee_id, false, &pool); for connection_id in pool.user_connection_ids(inviter_id) { self.peer.send( connection_id, @@ -2063,7 +2067,7 @@ async fn request_contact( return Err(anyhow!("cannot add yourself as a contact"))?; } - session + let notifications = session .db() .await .send_contact_request(requester_id, responder_id) @@ -2086,16 +2090,14 @@ async fn request_contact( .incoming_requests .push(proto::IncomingContactRequest { requester_id: requester_id.to_proto(), - should_notify: true, }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(responder_id) - { + let connection_pool = session.connection_pool().await; + for connection_id in connection_pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*connection_pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2114,7 +2116,8 @@ async fn respond_to_contact_request( } else { let accept = request.response == proto::ContactRequestResponse::Accept as i32; - db.respond_to_contact_request(responder_id, requester_id, accept) + let notifications = db + .respond_to_contact_request(responder_id, requester_id, accept) .await?; let requester_busy = db.is_user_busy(requester_id).await?; let responder_busy = db.is_user_busy(responder_id).await?; @@ -2125,7 +2128,7 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(requester_id, false, requester_busy, &pool)); + .push(contact_for_user(requester_id, requester_busy, &pool)); } update .remove_incoming_requests @@ -2139,14 +2142,17 @@ async fn respond_to_contact_request( if accept { update .contacts - .push(contact_for_user(responder_id, true, responder_busy, &pool)); + .push(contact_for_user(responder_id, responder_busy, &pool)); } update .remove_outgoing_requests .push(responder_id.to_proto()); + for connection_id in pool.user_connection_ids(requester_id) { session.peer.send(connection_id, update.clone())?; } + + send_notifications(&*pool, &session.peer, notifications); } response.send(proto::Ack {})?; @@ -2161,7 +2167,8 @@ async fn remove_contact( let requester_id = session.user_id; let responder_id = UserId::from_proto(request.user_id); let db = session.db().await; - let contact_accepted = db.remove_contact(requester_id, responder_id).await?; + let (contact_accepted, deleted_notification_id) = + db.remove_contact(requester_id, responder_id).await?; let pool = session.connection_pool().await; // Update outgoing contact requests of requester @@ -2188,6 +2195,14 @@ async fn remove_contact( } for connection_id in pool.user_connection_ids(responder_id) { session.peer.send(connection_id, update.clone())?; + if let Some(notification_id) = deleted_notification_id { + session.peer.send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + )?; + } } response.send(proto::Ack {})?; @@ -2282,13 +2297,14 @@ async fn invite_channel_member( let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); let invitee_id = UserId::from_proto(request.user_id); - db.invite_channel_member( - channel_id, - invitee_id, - session.user_id, - request.role().into(), - ) - .await?; + let notifications = db + .invite_channel_member( + channel_id, + invitee_id, + session.user_id, + request.role().into(), + ) + .await?; let channel = db.get_channel(channel_id, session.user_id).await?; @@ -2298,14 +2314,14 @@ async fn invite_channel_member( visibility: channel.visibility.into(), name: channel.name, }); - for connection_id in session - .connection_pool() - .await - .user_connection_ids(invitee_id) - { + + let pool = session.connection_pool().await; + for connection_id in pool.user_connection_ids(invitee_id) { session.peer.send(connection_id, update.clone())?; } + send_notifications(&*pool, &session.peer, notifications); + response.send(proto::Ack {})?; Ok(()) } @@ -2319,7 +2335,8 @@ async fn remove_channel_member( let channel_id = ChannelId::from_proto(request.channel_id); let member_id = UserId::from_proto(request.user_id); - db.remove_channel_member(channel_id, member_id, session.user_id) + let removed_notification_id = db + .remove_channel_member(channel_id, member_id, session.user_id) .await?; let mut update = proto::UpdateChannels::default(); @@ -2330,7 +2347,18 @@ async fn remove_channel_member( .await .user_connection_ids(member_id) { - session.peer.send(connection_id, update.clone())?; + session.peer.send(connection_id, update.clone()).trace_err(); + if let Some(notification_id) = removed_notification_id { + session + .peer + .send( + connection_id, + proto::DeleteNotification { + notification_id: notification_id.to_proto(), + }, + ) + .trace_err(); + } } response.send(proto::Ack {})?; @@ -2592,7 +2620,8 @@ async fn respond_to_channel_invite( ) -> Result<()> { let db = session.db().await; let channel_id = ChannelId::from_proto(request.channel_id); - db.respond_to_channel_invite(channel_id, session.user_id, request.accept) + let notifications = db + .respond_to_channel_invite(channel_id, session.user_id, request.accept) .await?; if request.accept { @@ -2604,6 +2633,12 @@ async fn respond_to_channel_invite( .push(channel_id.to_proto()); session.peer.send(session.connection_id, update)?; } + + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); response.send(proto::Ack {})?; Ok(()) @@ -2891,6 +2926,29 @@ fn channel_buffer_updated( }); } +fn send_notifications( + connection_pool: &ConnectionPool, + peer: &Peer, + notifications: db::NotificationBatch, +) { + for (user_id, notification) in notifications { + for connection_id in connection_pool.user_connection_ids(user_id) { + if let Err(error) = peer.send( + connection_id, + proto::AddNotification { + notification: Some(notification.clone()), + }, + ) { + tracing::error!( + "failed to send notification to {:?} {}", + connection_id, + error + ); + } + } + } +} + async fn send_channel_message( request: proto::SendChannelMessage, response: Response, @@ -2905,19 +2963,27 @@ async fn send_channel_message( return Err(anyhow!("message can't be blank"))?; } + // TODO: adjust mentions if body is trimmed + let timestamp = OffsetDateTime::now_utc(); let nonce = request .nonce .ok_or_else(|| anyhow!("nonce can't be blank"))?; let channel_id = ChannelId::from_proto(request.channel_id); - let (message_id, connection_ids, non_participants) = session + let CreatedChannelMessage { + message_id, + participant_connection_ids, + channel_members, + notifications, + } = session .db() .await .create_channel_message( channel_id, session.user_id, &body, + &request.mentions, timestamp, nonce.clone().into(), ) @@ -2926,18 +2992,23 @@ async fn send_channel_message( sender_id: session.user_id.to_proto(), id: message_id.to_proto(), body, + mentions: request.mentions, timestamp: timestamp.unix_timestamp() as u64, nonce: Some(nonce), }; - broadcast(Some(session.connection_id), connection_ids, |connection| { - session.peer.send( - connection, - proto::ChannelMessageSent { - channel_id: channel_id.to_proto(), - message: Some(message.clone()), - }, - ) - }); + broadcast( + Some(session.connection_id), + participant_connection_ids, + |connection| { + session.peer.send( + connection, + proto::ChannelMessageSent { + channel_id: channel_id.to_proto(), + message: Some(message.clone()), + }, + ) + }, + ); response.send(proto::SendChannelMessageResponse { message: Some(message), })?; @@ -2945,7 +3016,7 @@ async fn send_channel_message( let pool = &*session.connection_pool().await; broadcast( None, - non_participants + channel_members .iter() .flat_map(|user_id| pool.user_connection_ids(*user_id)), |peer_id| { @@ -2961,6 +3032,7 @@ async fn send_channel_message( ) }, ); + send_notifications(pool, &session.peer, notifications); Ok(()) } @@ -2990,11 +3062,16 @@ async fn acknowledge_channel_message( ) -> Result<()> { let channel_id = ChannelId::from_proto(request.channel_id); let message_id = MessageId::from_proto(request.message_id); - session + let notifications = session .db() .await .observe_channel_message(channel_id, session.user_id, message_id) .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); Ok(()) } @@ -3069,6 +3146,72 @@ async fn get_channel_messages( Ok(()) } +async fn get_channel_messages_by_id( + request: proto::GetChannelMessagesById, + response: Response, + session: Session, +) -> Result<()> { + let message_ids = request + .message_ids + .iter() + .map(|id| MessageId::from_proto(*id)) + .collect::>(); + let messages = session + .db() + .await + .get_channel_messages_by_id(session.user_id, &message_ids) + .await?; + response.send(proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + })?; + Ok(()) +} + +async fn get_notifications( + request: proto::GetNotifications, + response: Response, + session: Session, +) -> Result<()> { + let notifications = session + .db() + .await + .get_notifications( + session.user_id, + NOTIFICATION_COUNT_PER_PAGE, + request + .before_id + .map(|id| db::NotificationId::from_proto(id)), + ) + .await?; + response.send(proto::GetNotificationsResponse { + done: notifications.len() < NOTIFICATION_COUNT_PER_PAGE, + notifications, + })?; + Ok(()) +} + +async fn mark_notification_as_read( + request: proto::MarkNotificationRead, + response: Response, + session: Session, +) -> Result<()> { + let database = &session.db().await; + let notifications = database + .mark_notification_as_read_by_id( + session.user_id, + NotificationId::from_proto(request.notification_id), + ) + .await?; + send_notifications( + &*session.connection_pool().await, + &session.peer, + notifications, + ); + response.send(proto::Ack {})?; + Ok(()) +} + async fn update_diff_base(request: proto::UpdateDiffBase, session: Session) -> Result<()> { let project_id = ProjectId::from_proto(request.project_id); let project_connection_ids = session @@ -3197,42 +3340,28 @@ fn build_initial_contacts_update( for contact in contacts { match contact { - db::Contact::Accepted { - user_id, - should_notify, - busy, - } => { - update - .contacts - .push(contact_for_user(user_id, should_notify, busy, &pool)); + db::Contact::Accepted { user_id, busy } => { + update.contacts.push(contact_for_user(user_id, busy, &pool)); } db::Contact::Outgoing { user_id } => update.outgoing_requests.push(user_id.to_proto()), - db::Contact::Incoming { - user_id, - should_notify, - } => update - .incoming_requests - .push(proto::IncomingContactRequest { - requester_id: user_id.to_proto(), - should_notify, - }), + db::Contact::Incoming { user_id } => { + update + .incoming_requests + .push(proto::IncomingContactRequest { + requester_id: user_id.to_proto(), + }) + } } } update } -fn contact_for_user( - user_id: UserId, - should_notify: bool, - busy: bool, - pool: &ConnectionPool, -) -> proto::Contact { +fn contact_for_user(user_id: UserId, busy: bool, pool: &ConnectionPool) -> proto::Contact { proto::Contact { user_id: user_id.to_proto(), online: pool.is_user_online(user_id), busy, - should_notify, } } @@ -3293,7 +3422,7 @@ async fn update_user_contacts(user_id: UserId, session: &Session) -> Result<()> let busy = db.is_user_busy(user_id).await?; let pool = session.connection_pool().await; - let updated_contact = contact_for_user(user_id, false, busy, &pool); + let updated_contact = contact_for_user(user_id, busy, &pool); for contact in contacts { if let db::Contact::Accepted { user_id: contact_user_id, diff --git a/crates/collab/src/tests.rs b/crates/collab/src/tests.rs index e78bbe3466318cfc44fbcf298cef65a86350a0b8..139910e1f6f25281139865f0b80b8e20e15648d0 100644 --- a/crates/collab/src/tests.rs +++ b/crates/collab/src/tests.rs @@ -6,6 +6,7 @@ mod channel_message_tests; mod channel_tests; mod following_tests; mod integration_tests; +mod notification_tests; mod random_channel_buffer_tests; mod random_project_collaboration_tests; mod randomized_test_helpers; diff --git a/crates/collab/src/tests/channel_message_tests.rs b/crates/collab/src/tests/channel_message_tests.rs index 0fc3b085edde00cbfe552352fe5590753f686134..918eb053d3d52507c810966a7f59941337b8aff4 100644 --- a/crates/collab/src/tests/channel_message_tests.rs +++ b/crates/collab/src/tests/channel_message_tests.rs @@ -1,27 +1,30 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; -use channel::{ChannelChat, ChannelMessageId}; +use channel::{ChannelChat, ChannelMessageId, MessageParams}; use collab_ui::chat_panel::ChatPanel; use gpui::{executor::Deterministic, BorrowAppContext, ModelHandle, TestAppContext}; +use rpc::Notification; use std::sync::Arc; use workspace::dock::Panel; #[gpui::test] async fn test_basic_channel_messages( deterministic: Arc, - cx_a: &mut TestAppContext, - cx_b: &mut TestAppContext, + mut cx_a: &mut TestAppContext, + mut cx_b: &mut TestAppContext, + mut cx_c: &mut TestAppContext, ) { deterministic.forbid_parking(); let mut server = TestServer::start(&deterministic).await; let client_a = server.create_client(cx_a, "user_a").await; let client_b = server.create_client(cx_b, "user_b").await; + let client_c = server.create_client(cx_c, "user_c").await; let channel_id = server .make_channel( "the-channel", None, (&client_a, cx_a), - &mut [(&client_b, cx_b)], + &mut [(&client_b, cx_b), (&client_c, cx_c)], ) .await; @@ -36,8 +39,17 @@ async fn test_basic_channel_messages( .await .unwrap(); - channel_chat_a - .update(cx_a, |c, cx| c.send_message("one".into(), cx).unwrap()) + let message_id = channel_chat_a + .update(cx_a, |c, cx| { + c.send_message( + MessageParams { + text: "hi @user_c!".into(), + mentions: vec![(3..10, client_c.id())], + }, + cx, + ) + .unwrap() + }) .await .unwrap(); channel_chat_a @@ -52,15 +64,55 @@ async fn test_basic_channel_messages( .unwrap(); deterministic.run_until_parked(); - channel_chat_a.update(cx_a, |c, _| { + + let channel_chat_c = client_c + .channel_store() + .update(cx_c, |store, cx| store.open_channel_chat(channel_id, cx)) + .await + .unwrap(); + + for (chat, cx) in [ + (&channel_chat_a, &mut cx_a), + (&channel_chat_b, &mut cx_b), + (&channel_chat_c, &mut cx_c), + ] { + chat.update(*cx, |c, _| { + assert_eq!( + c.messages() + .iter() + .map(|m| (m.body.as_str(), m.mentions.as_slice())) + .collect::>(), + vec![ + ("hi @user_c!", [(3..10, client_c.id())].as_slice()), + ("two", &[]), + ("three", &[]) + ], + "results for user {}", + c.client().id(), + ); + }); + } + + client_c.notification_store().update(cx_c, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); assert_eq!( - c.messages() - .iter() - .map(|m| m.body.as_str()) - .collect::>(), - vec!["one", "two", "three"] + store.notification_at(0).unwrap().notification, + Notification::ChannelMessageMention { + message_id, + sender_id: client_a.id(), + channel_id, + } ); - }) + assert_eq!( + store.notification_at(1).unwrap().notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + }); } #[gpui::test] @@ -280,7 +332,7 @@ async fn test_channel_message_changes( chat_panel_b .update(cx_b, |chat_panel, cx| { chat_panel.set_active(true, cx); - chat_panel.select_channel(channel_id, cx) + chat_panel.select_channel(channel_id, None, cx) }) .await .unwrap(); diff --git a/crates/collab/src/tests/channel_tests.rs b/crates/collab/src/tests/channel_tests.rs index 1bb8c92ac80ac45b78ccff76de8f15c1468acaef..54a958e71c62a5a8f6051b0def9388f070271ca5 100644 --- a/crates/collab/src/tests/channel_tests.rs +++ b/crates/collab/src/tests/channel_tests.rs @@ -125,8 +125,8 @@ async fn test_core_channels( // Client B accepts the invitation. client_b .channel_store() - .update(cx_b, |channels, _| { - channels.respond_to_channel_invite(channel_a_id, true) + .update(cx_b, |channels, cx| { + channels.respond_to_channel_invite(channel_a_id, true, cx) }) .await .unwrap(); @@ -884,8 +884,8 @@ async fn test_lost_channel_creation( // Client B accepts the invite client_b .channel_store() - .update(cx_b, |channel_store, _| { - channel_store.respond_to_channel_invite(channel_id, true) + .update(cx_b, |channel_store, cx| { + channel_store.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); diff --git a/crates/collab/src/tests/following_tests.rs b/crates/collab/src/tests/following_tests.rs index f3857e3db37343aee1d4ba68116a0bc236f61e98..a28f2ae87f0984241ca7df30fac0807d4e0fa31b 100644 --- a/crates/collab/src/tests/following_tests.rs +++ b/crates/collab/src/tests/following_tests.rs @@ -1,6 +1,6 @@ use crate::{rpc::RECONNECT_TIMEOUT, tests::TestServer}; use call::ActiveCall; -use collab_ui::project_shared_notification::ProjectSharedNotification; +use collab_ui::notifications::project_shared_notification::ProjectSharedNotification; use editor::{Editor, ExcerptRange, MultiBuffer}; use gpui::{executor::Deterministic, geometry::vector::vec2f, TestAppContext, ViewHandle}; use live_kit_client::MacOSDisplay; diff --git a/crates/collab/src/tests/notification_tests.rs b/crates/collab/src/tests/notification_tests.rs new file mode 100644 index 0000000000000000000000000000000000000000..1114470449341de4310205c9d702e2833f62370f --- /dev/null +++ b/crates/collab/src/tests/notification_tests.rs @@ -0,0 +1,159 @@ +use crate::tests::TestServer; +use gpui::{executor::Deterministic, TestAppContext}; +use notifications::NotificationEvent; +use parking_lot::Mutex; +use rpc::{proto, Notification}; +use std::sync::Arc; + +#[gpui::test] +async fn test_notifications( + deterministic: Arc, + cx_a: &mut TestAppContext, + cx_b: &mut TestAppContext, +) { + deterministic.forbid_parking(); + let mut server = TestServer::start(&deterministic).await; + let client_a = server.create_client(cx_a, "user_a").await; + let client_b = server.create_client(cx_b, "user_b").await; + + let notification_events_a = Arc::new(Mutex::new(Vec::new())); + let notification_events_b = Arc::new(Mutex::new(Vec::new())); + client_a.notification_store().update(cx_a, |_, cx| { + let events = notification_events_a.clone(); + cx.subscribe(&cx.handle(), move |_, _, event, _| { + events.lock().push(event.clone()); + }) + .detach() + }); + client_b.notification_store().update(cx_b, |_, cx| { + let events = notification_events_b.clone(); + cx.subscribe(&cx.handle(), move |_, _, event, _| { + events.lock().push(event.clone()); + }) + .detach() + }); + + // Client A sends a contact request to client B. + client_a + .user_store() + .update(cx_a, |store, cx| store.request_contact(client_b.id(), cx)) + .await + .unwrap(); + + // Client B receives a contact request notification and responds to the + // request, accepting it. + deterministic.run_until_parked(); + client_b.notification_store().update(cx_b, |store, cx| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ContactRequest { + sender_id: client_a.id() + } + ); + assert!(!entry.is_read); + assert_eq!( + ¬ification_events_b.lock()[0..], + &[ + NotificationEvent::NewNotification { + entry: entry.clone(), + }, + NotificationEvent::NotificationsUpdated { + old_range: 0..0, + new_count: 1 + } + ] + ); + + store.respond_to_notification(entry.notification.clone(), true, cx); + }); + + // Client B sees the notification is now read, and that they responded. + deterministic.run_until_parked(); + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 0); + + let entry = store.notification_at(0).unwrap(); + assert!(entry.is_read); + assert_eq!(entry.response, Some(true)); + assert_eq!( + ¬ification_events_b.lock()[2..], + &[ + NotificationEvent::NotificationRead { + entry: entry.clone(), + }, + NotificationEvent::NotificationsUpdated { + old_range: 0..1, + new_count: 1 + } + ] + ); + }); + + // Client A receives a notification that client B accepted their request. + client_a.notification_store().read_with(cx_a, |store, _| { + assert_eq!(store.notification_count(), 1); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ContactRequestAccepted { + responder_id: client_b.id() + } + ); + assert!(!entry.is_read); + }); + + // Client A creates a channel and invites client B to be a member. + let channel_id = client_a + .channel_store() + .update(cx_a, |store, cx| { + store.create_channel("the-channel", None, cx) + }) + .await + .unwrap(); + client_a + .channel_store() + .update(cx_a, |store, cx| { + store.invite_member(channel_id, client_b.id(), proto::ChannelRole::Member, cx) + }) + .await + .unwrap(); + + // Client B receives a channel invitation notification and responds to the + // invitation, accepting it. + deterministic.run_until_parked(); + client_b.notification_store().update(cx_b, |store, cx| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 1); + + let entry = store.notification_at(0).unwrap(); + assert_eq!( + entry.notification, + Notification::ChannelInvitation { + channel_id, + channel_name: "the-channel".to_string(), + inviter_id: client_a.id() + } + ); + assert!(!entry.is_read); + + store.respond_to_notification(entry.notification.clone(), true, cx); + }); + + // Client B sees the notification is now read, and that they responded. + deterministic.run_until_parked(); + client_b.notification_store().read_with(cx_b, |store, _| { + assert_eq!(store.notification_count(), 2); + assert_eq!(store.unread_notification_count(), 0); + + let entry = store.notification_at(0).unwrap(); + assert!(entry.is_read); + assert_eq!(entry.response, Some(true)); + }); +} diff --git a/crates/collab/src/tests/randomized_test_helpers.rs b/crates/collab/src/tests/randomized_test_helpers.rs index 39598bdaf9d219f78e5a88b07d159798a167237b..1cec9452823ac2f96389f3e69f4876605d4092fd 100644 --- a/crates/collab/src/tests/randomized_test_helpers.rs +++ b/crates/collab/src/tests/randomized_test_helpers.rs @@ -208,8 +208,7 @@ impl TestPlan { false, NewUserParams { github_login: username.clone(), - github_user_id: (ix + 1) as i32, - invite_count: 0, + github_user_id: ix as i32, }, ) .await diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index c37ea19d528d9f30ef4cc69de31b3c3d2370a86a..f7c4fa4146cd8a6e665048ce9a8f3e1b4c391a25 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -16,6 +16,7 @@ use futures::{channel::oneshot, StreamExt as _}; use gpui::{executor::Deterministic, ModelHandle, Task, TestAppContext, WindowHandle}; use language::LanguageRegistry; use node_runtime::FakeNodeRuntime; +use notifications::NotificationStore; use parking_lot::Mutex; use project::{Project, WorktreeId}; use rpc::{proto::ChannelRole, RECEIVE_TIMEOUT}; @@ -46,6 +47,7 @@ pub struct TestClient { pub username: String, pub app_state: Arc, channel_store: ModelHandle, + notification_store: ModelHandle, state: RefCell, } @@ -138,7 +140,6 @@ impl TestServer { NewUserParams { github_login: name.into(), github_user_id: 0, - invite_count: 0, }, ) .await @@ -231,7 +232,8 @@ impl TestServer { workspace::init(app_state.clone(), cx); audio::init((), cx); call::init(client.clone(), user_store.clone(), cx); - channel::init(&client, user_store, cx); + channel::init(&client, user_store.clone(), cx); + notifications::init(client.clone(), user_store, cx); }); client @@ -243,6 +245,7 @@ impl TestServer { app_state, username: name.to_string(), channel_store: cx.read(ChannelStore::global).clone(), + notification_store: cx.read(NotificationStore::global).clone(), state: Default::default(), }; client.wait_for_current_user(cx).await; @@ -338,8 +341,8 @@ impl TestServer { member_cx .read(ChannelStore::global) - .update(*member_cx, |channels, _| { - channels.respond_to_channel_invite(channel_id, true) + .update(*member_cx, |channels, cx| { + channels.respond_to_channel_invite(channel_id, true, cx) }) .await .unwrap(); @@ -448,6 +451,10 @@ impl TestClient { &self.channel_store } + pub fn notification_store(&self) -> &ModelHandle { + &self.notification_store + } + pub fn user_store(&self) -> &ModelHandle { &self.app_state.user_store } @@ -630,8 +637,8 @@ impl TestClient { other_cx .read(ChannelStore::global) - .update(other_cx, |channel_store, _| { - channel_store.respond_to_channel_invite(channel, true) + .update(other_cx, |channel_store, cx| { + channel_store.respond_to_channel_invite(channel, true, cx) }) .await .unwrap(); diff --git a/crates/collab_ui/Cargo.toml b/crates/collab_ui/Cargo.toml index 98790778c98d69afa90743f8e40d94aa397cf886..8aee0da8dd5978f2a97fb6e52f406f3de7cf44ba 100644 --- a/crates/collab_ui/Cargo.toml +++ b/crates/collab_ui/Cargo.toml @@ -37,10 +37,12 @@ fuzzy = { path = "../fuzzy" } gpui = { path = "../gpui" } language = { path = "../language" } menu = { path = "../menu" } +notifications = { path = "../notifications" } rich_text = { path = "../rich_text" } picker = { path = "../picker" } project = { path = "../project" } -recent_projects = {path = "../recent_projects"} +recent_projects = { path = "../recent_projects" } +rpc = { path = "../rpc" } settings = { path = "../settings" } feature_flags = {path = "../feature_flags"} theme = { path = "../theme" } @@ -52,6 +54,7 @@ zed-actions = {path = "../zed-actions"} anyhow.workspace = true futures.workspace = true +lazy_static.workspace = true log.workspace = true schemars.workspace = true postage.workspace = true @@ -65,7 +68,12 @@ client = { path = "../client", features = ["test-support"] } collections = { path = "../collections", features = ["test-support"] } editor = { path = "../editor", features = ["test-support"] } gpui = { path = "../gpui", features = ["test-support"] } +notifications = { path = "../notifications", features = ["test-support"] } project = { path = "../project", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } settings = { path = "../settings", features = ["test-support"] } util = { path = "../util", features = ["test-support"] } workspace = { path = "../workspace", features = ["test-support"] } + +pretty_assertions.workspace = true +tree-sitter-markdown.workspace = true diff --git a/crates/collab_ui/src/chat_panel.rs b/crates/collab_ui/src/chat_panel.rs index a8c4006cb8ce7c18e3a2eaa09eb1dafc98becd26..5b922037c51c1e5995d2eff31975915f9cd679ab 100644 --- a/crates/collab_ui/src/chat_panel.rs +++ b/crates/collab_ui/src/chat_panel.rs @@ -1,4 +1,6 @@ -use crate::{channel_view::ChannelView, ChatPanelSettings}; +use crate::{ + channel_view::ChannelView, is_channels_feature_enabled, render_avatar, ChatPanelSettings, +}; use anyhow::Result; use call::ActiveCall; use channel::{ChannelChat, ChannelChatEvent, ChannelMessageId, ChannelStore}; @@ -6,18 +8,18 @@ use client::Client; use collections::HashMap; use db::kvp::KEY_VALUE_STORE; use editor::Editor; -use feature_flags::{ChannelsAlpha, FeatureFlagAppExt}; use gpui::{ actions, elements::*, platform::{CursorStyle, MouseButton}, serde_json, views::{ItemType, Select, SelectStyle}, - AnyViewHandle, AppContext, AsyncAppContext, Entity, ImageData, ModelHandle, Subscription, Task, - View, ViewContext, ViewHandle, WeakViewHandle, + AnyViewHandle, AppContext, AsyncAppContext, Entity, ModelHandle, Subscription, Task, View, + ViewContext, ViewHandle, WeakViewHandle, }; -use language::{language_settings::SoftWrap, LanguageRegistry}; +use language::LanguageRegistry; use menu::Confirm; +use message_editor::MessageEditor; use project::Fs; use rich_text::RichText; use serde::{Deserialize, Serialize}; @@ -31,6 +33,8 @@ use workspace::{ Workspace, }; +mod message_editor; + const MESSAGE_LOADING_THRESHOLD: usize = 50; const CHAT_PANEL_KEY: &'static str = "ChatPanel"; @@ -40,7 +44,7 @@ pub struct ChatPanel { languages: Arc, active_chat: Option<(ModelHandle, Subscription)>, message_list: ListState, - input_editor: ViewHandle, + input_editor: ViewHandle, channel_select: ViewHandle