From bc63fca8d716794c107f6181d48379452720da3c Mon Sep 17 00:00:00 2001 From: Max Brunsfeld Date: Fri, 27 Aug 2021 14:58:28 -0700 Subject: [PATCH] Fetch older messages when scrolling up in the chat message list Co-Authored-By: Nathan Sobo --- gpui/src/elements/list.rs | 14 +++ server/src/admin.rs | 6 +- server/src/db.rs | 24 ++++- server/src/rpc.rs | 67 ++++++++++++- zed/src/channel.rs | 203 +++++++++++++++++++++++++++----------- zed/src/chat_panel.rs | 17 +++- zrpc/proto/zed.proto | 13 +++ zrpc/src/proto.rs | 3 + 8 files changed, 280 insertions(+), 67 deletions(-) diff --git a/gpui/src/elements/list.rs b/gpui/src/elements/list.rs index b6773091c0b1c319280e08b584d6a823c3ddeef6..c7950482c379e0fdfc90e9fa93cd5980ab976197 100644 --- a/gpui/src/elements/list.rs +++ b/gpui/src/elements/list.rs @@ -29,6 +29,7 @@ struct StateInner { heights: SumTree, scroll_position: f32, orientation: Orientation, + scroll_handler: Option, &mut EventContext)>>, } #[derive(Clone, Debug)] @@ -272,6 +273,7 @@ impl ListState { heights, scroll_position: 0., orientation, + scroll_handler: None, }))) } @@ -290,6 +292,13 @@ impl ListState { drop(old_heights); state.heights = new_heights; } + + pub fn set_scroll_handler( + &mut self, + handler: impl FnMut(Range, &mut EventContext) + 'static, + ) { + self.0.borrow_mut().scroll_handler = Some(Box::new(handler)) + } } impl StateInner { @@ -320,6 +329,11 @@ impl StateInner { Orientation::Bottom => delta.y(), }; self.scroll_position = (self.scroll_position + delta_y).max(0.).min(scroll_max); + + if self.scroll_handler.is_some() { + let range = self.visible_range(height); + self.scroll_handler.as_mut().unwrap()(range, cx); + } cx.notify(); true diff --git a/server/src/admin.rs b/server/src/admin.rs index d6e3f8161589e6b2420cc9a7d54bd212ec8d76b1..47b29b5d0294168be720749e94e4f8ed838b802e 100644 --- a/server/src/admin.rs +++ b/server/src/admin.rs @@ -85,7 +85,7 @@ async fn post_user(mut request: Request) -> tide::Result { async fn put_user(mut request: Request) -> tide::Result { request.require_admin().await?; - let user_id = request.param("id")?.parse::()?; + let user_id = request.param("id")?.parse()?; #[derive(Deserialize)] struct Body { @@ -104,14 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result { async fn delete_user(request: Request) -> tide::Result { request.require_admin().await?; - let user_id = db::UserId(request.param("id")?.parse::()?); + let user_id = db::UserId(request.param("id")?.parse()?); request.db().delete_user(user_id).await?; Ok(tide::Redirect::new("/admin").into()) } async fn delete_signup(request: Request) -> tide::Result { request.require_admin().await?; - let signup_id = db::SignupId(request.param("id")?.parse::()?); + let signup_id = db::SignupId(request.param("id")?.parse()?); request.db().delete_signup(signup_id).await?; Ok(tide::Redirect::new("/admin").into()) } diff --git a/server/src/db.rs b/server/src/db.rs index 2f1cbc5fba181457d823aa1a5a98c7bc0cbafc67..f61f5d82b44332ff452c1016239dbe366f840620 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -380,6 +380,7 @@ impl Db { &self, channel_id: ChannelId, count: usize, + before_id: Option, ) -> Result> { test_support!(self, { let query = r#" @@ -389,14 +390,16 @@ impl Db { FROM channel_messages WHERE - channel_id = $1 + channel_id = $1 AND + id < $2 ORDER BY id DESC - LIMIT $2 + LIMIT $3 ) as recent_messages ORDER BY id ASC "#; sqlx::query_as(query) .bind(channel_id.0) + .bind(before_id.unwrap_or(MessageId::MAX)) .bind(count as i64) .fetch_all(&self.pool) .await @@ -412,6 +415,9 @@ macro_rules! id_type { pub struct $name(pub i32); impl $name { + #[allow(unused)] + pub const MAX: Self = Self(i32::MAX); + #[allow(unused)] pub fn from_proto(value: u64) -> Self { Self(value as i32) @@ -512,10 +518,22 @@ pub mod tests { .unwrap(); } - let messages = db.get_recent_channel_messages(channel, 5).await.unwrap(); + let messages = db + .get_recent_channel_messages(channel, 5, None) + .await + .unwrap(); assert_eq!( messages.iter().map(|m| &m.body).collect::>(), ["5", "6", "7", "8", "9"] ); + + let prev_messages = db + .get_recent_channel_messages(channel, 4, Some(messages[0].id)) + .await + .unwrap(); + assert_eq!( + prev_messages.iter().map(|m| &m.body).collect::>(), + ["1", "2", "3", "4"] + ); } } diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 00464450fab56b2dcde76afdccdfce083e250504..b74b923b0f2699fcf27384d8f83a99db659adc72 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1,6 +1,6 @@ use super::{ auth, - db::{ChannelId, UserId}, + db::{ChannelId, MessageId, UserId}, AppState, }; use anyhow::anyhow; @@ -77,6 +77,8 @@ struct Channel { connection_ids: HashSet, } +const MESSAGE_COUNT_PER_PAGE: usize = 50; + impl Server { pub fn new( app_state: Arc, @@ -105,7 +107,8 @@ impl Server { .add_handler(Server::get_users) .add_handler(Server::join_channel) .add_handler(Server::leave_channel) - .add_handler(Server::send_channel_message); + .add_handler(Server::send_channel_message) + .add_handler(Server::get_channel_messages); Arc::new(server) } @@ -592,7 +595,7 @@ impl Server { let messages = self .app_state .db - .get_recent_channel_messages(channel_id, 50) + .get_recent_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None) .await? .into_iter() .map(|msg| proto::ChannelMessage { @@ -601,9 +604,15 @@ impl Server { timestamp: msg.sent_at.unix_timestamp() as u64, sender_id: msg.sender_id.to_proto(), }) - .collect(); + .collect::>(); self.peer - .respond(request.receipt(), proto::JoinChannelResponse { messages }) + .respond( + request.receipt(), + proto::JoinChannelResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + }, + ) .await?; Ok(()) } @@ -685,6 +694,54 @@ impl Server { Ok(()) } + async fn get_channel_messages( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let user_id = self + .state + .read() + .await + .user_id_for_connection(request.sender_id)?; + let channel_id = ChannelId::from_proto(request.payload.channel_id); + if !self + .app_state + .db + .can_user_access_channel(user_id, channel_id) + .await? + { + Err(anyhow!("access denied"))?; + } + + let messages = self + .app_state + .db + .get_recent_channel_messages( + channel_id, + MESSAGE_COUNT_PER_PAGE, + Some(MessageId::from_proto(request.payload.before_message_id)), + ) + .await? + .into_iter() + .map(|msg| proto::ChannelMessage { + id: msg.id.to_proto(), + body: msg.body, + timestamp: msg.sent_at.unix_timestamp() as u64, + sender_id: msg.sender_id.to_proto(), + }) + .collect::>(); + self.peer + .respond( + request.receipt(), + proto::GetChannelMessagesResponse { + done: messages.len() < MESSAGE_COUNT_PER_PAGE, + messages, + }, + ) + .await?; + Ok(()) + } + async fn broadcast_in_worktree( &self, worktree_id: u64, diff --git a/zed/src/channel.rs b/zed/src/channel.rs index f1f2e076077e5b38a7a036fab679d6b880d972d4..f0f9d4f43c94fd1e818e0b2a3021e1c4eb173e94 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -37,6 +37,7 @@ pub struct ChannelDetails { pub struct Channel { details: ChannelDetails, messages: SumTree, + loaded_all_messages: bool, pending_messages: Vec, next_local_message_id: u64, user_store: Arc, @@ -70,7 +71,7 @@ pub enum ChannelListEvent {} #[derive(Clone, Debug, PartialEq)] pub enum ChannelEvent { - Message { + MessagesAdded { old_range: Range, new_count: usize, }, @@ -192,31 +193,12 @@ impl Channel { cx.spawn(|channel, mut cx| { async move { let response = rpc.request(proto::JoinChannel { channel_id }).await?; - - let unique_user_ids = response - .messages - .iter() - .map(|m| m.sender_id) - .collect::>() - .into_iter() - .collect(); - user_store.load_users(unique_user_ids).await?; - - let mut messages = Vec::with_capacity(response.messages.len()); - for message in response.messages { - messages.push(ChannelMessage::from_proto(message, &user_store).await?); - } + let messages = messages_from_proto(response.messages, &user_store).await?; + let loaded_all_messages = response.done; channel.update(&mut cx, |channel, cx| { - let old_count = channel.messages.summary().count; - let new_count = messages.len(); - - channel.messages = SumTree::new(); - channel.messages.extend(messages, &()); - cx.emit(ChannelEvent::Message { - old_range: 0..old_count, - new_count, - }); + channel.insert_messages(messages, cx); + channel.loaded_all_messages = loaded_all_messages; }); Ok(()) @@ -232,6 +214,7 @@ impl Channel { rpc, messages: Default::default(), pending_messages: Default::default(), + loaded_all_messages: false, next_local_message_id: 0, _subscription, } @@ -264,15 +247,18 @@ impl Channel { .binary_search_by_key(&local_id, |msg| msg.local_id) { let body = this.pending_messages.remove(i).body; - this.insert_message( - ChannelMessage { - id: response.message_id, - timestamp: OffsetDateTime::from_unix_timestamp( - response.timestamp as i64, - )?, - body, - sender, - }, + this.insert_messages( + SumTree::from_item( + ChannelMessage { + id: response.message_id, + timestamp: OffsetDateTime::from_unix_timestamp( + response.timestamp as i64, + )?, + body, + sender, + }, + &(), + ), cx, ); } @@ -286,6 +272,37 @@ impl Channel { Ok(()) } + pub fn load_more_messages(&mut self, cx: &mut ModelContext) -> bool { + if !self.loaded_all_messages { + let rpc = self.rpc.clone(); + let user_store = self.user_store.clone(); + let channel_id = self.details.id; + if let Some(before_message_id) = self.messages.first().map(|message| message.id) { + cx.spawn(|this, mut cx| { + async move { + let response = rpc + .request(proto::GetChannelMessages { + channel_id, + before_message_id, + }) + .await?; + let loaded_all_messages = response.done; + let messages = messages_from_proto(response.messages, &user_store).await?; + this.update(&mut cx, |this, cx| { + this.loaded_all_messages = loaded_all_messages; + this.insert_messages(messages, cx); + }); + Ok(()) + } + .log_err() + }) + .detach(); + return true; + } + } + false + } + pub fn message_count(&self) -> usize { self.messages.summary().count } @@ -326,7 +343,9 @@ impl Channel { cx.spawn(|this, mut cx| { async move { let message = ChannelMessage::from_proto(message, &user_store).await?; - this.update(&mut cx, |this, cx| this.insert_message(message, cx)); + this.update(&mut cx, |this, cx| { + this.insert_messages(SumTree::from_item(message, &()), cx) + }); Ok(()) } .log_err() @@ -335,27 +354,49 @@ impl Channel { Ok(()) } - fn insert_message(&mut self, message: ChannelMessage, cx: &mut ModelContext) { - let mut old_cursor = self.messages.cursor::(); - let mut new_messages = old_cursor.slice(&message.id, Bias::Left, &()); - let start_ix = old_cursor.sum_start().0; - let mut end_ix = start_ix; - if old_cursor.item().map_or(false, |m| m.id == message.id) { - old_cursor.next(&()); - end_ix += 1; + fn insert_messages(&mut self, messages: SumTree, cx: &mut ModelContext) { + if let Some((first_message, last_message)) = messages.first().zip(messages.last()) { + let mut old_cursor = self.messages.cursor::(); + let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &()); + let start_ix = old_cursor.sum_start().0; + let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &()); + let removed_count = removed_messages.summary().count; + let new_count = messages.summary().count; + let end_ix = start_ix + removed_count; + + new_messages.push_tree(messages, &()); + new_messages.push_tree(old_cursor.suffix(&()), &()); + drop(old_cursor); + self.messages = new_messages; + + cx.emit(ChannelEvent::MessagesAdded { + old_range: start_ix..end_ix, + new_count, + }); + cx.notify(); } + } +} - new_messages.push(message.clone(), &()); - new_messages.push_tree(old_cursor.suffix(&()), &()); - drop(old_cursor); - self.messages = new_messages; - - cx.emit(ChannelEvent::Message { - old_range: start_ix..end_ix, - new_count: 1, - }); - cx.notify(); +async fn messages_from_proto( + proto_messages: Vec, + user_store: &UserStore, +) -> Result> { + let unique_user_ids = proto_messages + .iter() + .map(|m| m.sender_id) + .collect::>() + .into_iter() + .collect(); + user_store.load_users(unique_user_ids).await?; + + let mut messages = Vec::with_capacity(proto_messages.len()); + for message in proto_messages { + messages.push(ChannelMessage::from_proto(message, &user_store).await?); } + let mut result = SumTree::new(); + result.extend(messages, &()); + Ok(result) } impl From for ChannelDetails { @@ -489,9 +530,11 @@ mod tests { sender_id: 6, }, ], + done: false, }, ) .await; + // Client requests all users for the received messages let mut get_users = server.receive::().await; get_users.payload.user_ids.sort(); @@ -518,7 +561,7 @@ mod tests { assert_eq!( channel.next_event(&cx).await, - ChannelEvent::Message { + ChannelEvent::MessagesAdded { old_range: 0..0, new_count: 2, } @@ -567,7 +610,7 @@ mod tests { assert_eq!( channel.next_event(&cx).await, - ChannelEvent::Message { + ChannelEvent::MessagesAdded { old_range: 2..2, new_count: 1, } @@ -580,7 +623,57 @@ mod tests { .collect::>(), &[("as-cii".into(), "c".into())] ) - }) + }); + + // Scroll up to view older messages. + channel.update(&mut cx, |channel, cx| { + assert!(channel.load_more_messages(cx)); + }); + let get_messages = server.receive::().await; + assert_eq!(get_messages.payload.channel_id, 5); + assert_eq!(get_messages.payload.before_message_id, 10); + server + .respond( + get_messages.receipt(), + proto::GetChannelMessagesResponse { + done: true, + messages: vec![ + proto::ChannelMessage { + id: 8, + body: "y".into(), + timestamp: 998, + sender_id: 5, + }, + proto::ChannelMessage { + id: 9, + body: "z".into(), + timestamp: 999, + sender_id: 6, + }, + ], + }, + ) + .await; + + assert_eq!( + channel.next_event(&cx).await, + ChannelEvent::MessagesAdded { + old_range: 0..0, + new_count: 2, + } + ); + channel.read_with(&cx, |channel, _| { + assert_eq!( + channel + .messages_in_range(0..2) + .map(|message| (message.sender.github_login.clone(), message.body.clone())) + .collect::>(), + &[ + ("nathansobo".into(), "y".into()), + ("maxbrunsfeld".into(), "z".into()) + ] + ); + }); } struct FakeServer { diff --git a/zed/src/chat_panel.rs b/zed/src/chat_panel.rs index 37409800feaad2beb810cc87dff5e5df5e062a7b..a8b4ed94bf790154f5117a74437c820d0a33de0d 100644 --- a/zed/src/chat_panel.rs +++ b/zed/src/chat_panel.rs @@ -22,9 +22,11 @@ pub struct ChatPanel { pub enum Event {} action!(Send); +action!(LoadMoreMessages); pub fn init(cx: &mut MutableAppContext) { cx.add_action(ChatPanel::send); + cx.add_action(ChatPanel::load_more_messages); cx.add_bindings(vec![Binding::new("enter", Send, Some("ChatPanel"))]); } @@ -78,6 +80,11 @@ impl ChatPanel { let subscription = cx.subscribe(&channel, Self::channel_did_change); self.message_list = ListState::new(channel.read(cx).message_count(), Orientation::Bottom); + self.message_list.set_scroll_handler(|visible_range, cx| { + if visible_range.start < 5 { + cx.dispatch_action(LoadMoreMessages); + } + }); self.active_channel = Some((channel, subscription)); } } @@ -89,7 +96,7 @@ impl ChatPanel { cx: &mut ViewContext, ) { match event { - ChannelEvent::Message { + ChannelEvent::MessagesAdded { old_range, new_count, } => { @@ -191,6 +198,14 @@ impl ChatPanel { .log_err(); } } + + fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext) { + if let Some((channel, _)) = self.active_channel.as_ref() { + channel.update(cx, |channel, cx| { + channel.load_more_messages(cx); + }) + } + } } impl Entity for ChatPanel { diff --git a/zrpc/proto/zed.proto b/zrpc/proto/zed.proto index f368cb2d475da9dab1e651a8d29bc961dbec1bb8..a94c0f62049d18dcc1bf04463c3cba1b0b33e219 100644 --- a/zrpc/proto/zed.proto +++ b/zrpc/proto/zed.proto @@ -32,6 +32,8 @@ message Envelope { SendChannelMessage send_channel_message = 27; SendChannelMessageResponse send_channel_message_response = 28; ChannelMessageSent channel_message_sent = 29; + GetChannelMessages get_channel_messages = 30; + GetChannelMessagesResponse get_channel_messages_response = 31; } } @@ -130,6 +132,7 @@ message JoinChannel { message JoinChannelResponse { repeated ChannelMessage messages = 1; + bool done = 2; } message LeaveChannel { @@ -159,6 +162,16 @@ message ChannelMessageSent { ChannelMessage message = 2; } +message GetChannelMessages { + uint64 channel_id = 1; + uint64 before_message_id = 2; +} + +message GetChannelMessagesResponse { + repeated ChannelMessage messages = 1; + bool done = 2; +} + // Entities message Peer { diff --git a/zrpc/src/proto.rs b/zrpc/src/proto.rs index 330e3afa48380629dc0d78dfbd76943ee813bfdb..3743a06a07d56a496a97084fb9cc839e1e7afce9 100644 --- a/zrpc/src/proto.rs +++ b/zrpc/src/proto.rs @@ -125,6 +125,8 @@ messages!( ChannelMessageSent, CloseBuffer, CloseWorktree, + GetChannelMessages, + GetChannelMessagesResponse, GetChannels, GetChannelsResponse, GetUsers, @@ -158,6 +160,7 @@ request_messages!( (SaveBuffer, BufferSaved), (ShareWorktree, ShareWorktreeResponse), (SendChannelMessage, SendChannelMessageResponse), + (GetChannelMessages, GetChannelMessagesResponse), ); entity_messages!(