Fetch older messages when scrolling up in the chat message list

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

gpui/src/elements/list.rs |  14 ++
server/src/admin.rs       |   6 
server/src/db.rs          |  24 ++++
server/src/rpc.rs         |  67 ++++++++++++-
zed/src/channel.rs        | 203 +++++++++++++++++++++++++++++-----------
zed/src/chat_panel.rs     |  17 +++
zrpc/proto/zed.proto      |  13 ++
zrpc/src/proto.rs         |   3 
8 files changed, 280 insertions(+), 67 deletions(-)

Detailed changes

gpui/src/elements/list.rs 🔗

@@ -29,6 +29,7 @@ struct StateInner {
     heights: SumTree<ElementHeight>,
     scroll_position: f32,
     orientation: Orientation,
+    scroll_handler: Option<Box<dyn FnMut(Range<usize>, &mut EventContext)>>,
 }
 
 #[derive(Clone, Debug)]
@@ -272,6 +273,7 @@ impl ListState {
             heights,
             scroll_position: 0.,
             orientation,
+            scroll_handler: None,
         })))
     }
 
@@ -290,6 +292,13 @@ impl ListState {
         drop(old_heights);
         state.heights = new_heights;
     }
+
+    pub fn set_scroll_handler(
+        &mut self,
+        handler: impl FnMut(Range<usize>, &mut EventContext) + 'static,
+    ) {
+        self.0.borrow_mut().scroll_handler = Some(Box::new(handler))
+    }
 }
 
 impl StateInner {
@@ -320,6 +329,11 @@ impl StateInner {
             Orientation::Bottom => delta.y(),
         };
         self.scroll_position = (self.scroll_position + delta_y).max(0.).min(scroll_max);
+
+        if self.scroll_handler.is_some() {
+            let range = self.visible_range(height);
+            self.scroll_handler.as_mut().unwrap()(range, cx);
+        }
         cx.notify();
 
         true

server/src/admin.rs 🔗

@@ -85,7 +85,7 @@ async fn post_user(mut request: Request) -> tide::Result {
 async fn put_user(mut request: Request) -> tide::Result {
     request.require_admin().await?;
 
-    let user_id = request.param("id")?.parse::<i32>()?;
+    let user_id = request.param("id")?.parse()?;
 
     #[derive(Deserialize)]
     struct Body {
@@ -104,14 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
 
 async fn delete_user(request: Request) -> tide::Result {
     request.require_admin().await?;
-    let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
+    let user_id = db::UserId(request.param("id")?.parse()?);
     request.db().delete_user(user_id).await?;
     Ok(tide::Redirect::new("/admin").into())
 }
 
 async fn delete_signup(request: Request) -> tide::Result {
     request.require_admin().await?;
-    let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
+    let signup_id = db::SignupId(request.param("id")?.parse()?);
     request.db().delete_signup(signup_id).await?;
     Ok(tide::Redirect::new("/admin").into())
 }

server/src/db.rs 🔗

@@ -380,6 +380,7 @@ impl Db {
         &self,
         channel_id: ChannelId,
         count: usize,
+        before_id: Option<MessageId>,
     ) -> Result<Vec<ChannelMessage>> {
         test_support!(self, {
             let query = r#"
@@ -389,14 +390,16 @@ impl Db {
                     FROM
                         channel_messages
                     WHERE
-                        channel_id = $1
+                        channel_id = $1 AND
+                        id < $2
                     ORDER BY id DESC
-                    LIMIT $2
+                    LIMIT $3
                 ) as recent_messages
                 ORDER BY id ASC
             "#;
             sqlx::query_as(query)
                 .bind(channel_id.0)
+                .bind(before_id.unwrap_or(MessageId::MAX))
                 .bind(count as i64)
                 .fetch_all(&self.pool)
                 .await
@@ -412,6 +415,9 @@ macro_rules! id_type {
         pub struct $name(pub i32);
 
         impl $name {
+            #[allow(unused)]
+            pub const MAX: Self = Self(i32::MAX);
+
             #[allow(unused)]
             pub fn from_proto(value: u64) -> Self {
                 Self(value as i32)
@@ -512,10 +518,22 @@ pub mod tests {
                 .unwrap();
         }
 
-        let messages = db.get_recent_channel_messages(channel, 5).await.unwrap();
+        let messages = db
+            .get_recent_channel_messages(channel, 5, None)
+            .await
+            .unwrap();
         assert_eq!(
             messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
             ["5", "6", "7", "8", "9"]
         );
+
+        let prev_messages = db
+            .get_recent_channel_messages(channel, 4, Some(messages[0].id))
+            .await
+            .unwrap();
+        assert_eq!(
+            prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+            ["1", "2", "3", "4"]
+        );
     }
 }

server/src/rpc.rs 🔗

@@ -1,6 +1,6 @@
 use super::{
     auth,
-    db::{ChannelId, UserId},
+    db::{ChannelId, MessageId, UserId},
     AppState,
 };
 use anyhow::anyhow;
@@ -77,6 +77,8 @@ struct Channel {
     connection_ids: HashSet<ConnectionId>,
 }
 
+const MESSAGE_COUNT_PER_PAGE: usize = 50;
+
 impl Server {
     pub fn new(
         app_state: Arc<AppState>,
@@ -105,7 +107,8 @@ impl Server {
             .add_handler(Server::get_users)
             .add_handler(Server::join_channel)
             .add_handler(Server::leave_channel)
-            .add_handler(Server::send_channel_message);
+            .add_handler(Server::send_channel_message)
+            .add_handler(Server::get_channel_messages);
 
         Arc::new(server)
     }
@@ -592,7 +595,7 @@ impl Server {
         let messages = self
             .app_state
             .db
-            .get_recent_channel_messages(channel_id, 50)
+            .get_recent_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
             .await?
             .into_iter()
             .map(|msg| proto::ChannelMessage {
@@ -601,9 +604,15 @@ impl Server {
                 timestamp: msg.sent_at.unix_timestamp() as u64,
                 sender_id: msg.sender_id.to_proto(),
             })
-            .collect();
+            .collect::<Vec<_>>();
         self.peer
-            .respond(request.receipt(), proto::JoinChannelResponse { messages })
+            .respond(
+                request.receipt(),
+                proto::JoinChannelResponse {
+                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+                    messages,
+                },
+            )
             .await?;
         Ok(())
     }
@@ -685,6 +694,54 @@ impl Server {
         Ok(())
     }
 
+    async fn get_channel_messages(
+        self: Arc<Self>,
+        request: TypedEnvelope<proto::GetChannelMessages>,
+    ) -> tide::Result<()> {
+        let user_id = self
+            .state
+            .read()
+            .await
+            .user_id_for_connection(request.sender_id)?;
+        let channel_id = ChannelId::from_proto(request.payload.channel_id);
+        if !self
+            .app_state
+            .db
+            .can_user_access_channel(user_id, channel_id)
+            .await?
+        {
+            Err(anyhow!("access denied"))?;
+        }
+
+        let messages = self
+            .app_state
+            .db
+            .get_recent_channel_messages(
+                channel_id,
+                MESSAGE_COUNT_PER_PAGE,
+                Some(MessageId::from_proto(request.payload.before_message_id)),
+            )
+            .await?
+            .into_iter()
+            .map(|msg| proto::ChannelMessage {
+                id: msg.id.to_proto(),
+                body: msg.body,
+                timestamp: msg.sent_at.unix_timestamp() as u64,
+                sender_id: msg.sender_id.to_proto(),
+            })
+            .collect::<Vec<_>>();
+        self.peer
+            .respond(
+                request.receipt(),
+                proto::GetChannelMessagesResponse {
+                    done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+                    messages,
+                },
+            )
+            .await?;
+        Ok(())
+    }
+
     async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
         &self,
         worktree_id: u64,

zed/src/channel.rs 🔗

@@ -37,6 +37,7 @@ pub struct ChannelDetails {
 pub struct Channel {
     details: ChannelDetails,
     messages: SumTree<ChannelMessage>,
+    loaded_all_messages: bool,
     pending_messages: Vec<PendingChannelMessage>,
     next_local_message_id: u64,
     user_store: Arc<UserStore>,
@@ -70,7 +71,7 @@ pub enum ChannelListEvent {}
 
 #[derive(Clone, Debug, PartialEq)]
 pub enum ChannelEvent {
-    Message {
+    MessagesAdded {
         old_range: Range<usize>,
         new_count: usize,
     },
@@ -192,31 +193,12 @@ impl Channel {
             cx.spawn(|channel, mut cx| {
                 async move {
                     let response = rpc.request(proto::JoinChannel { channel_id }).await?;
-
-                    let unique_user_ids = response
-                        .messages
-                        .iter()
-                        .map(|m| m.sender_id)
-                        .collect::<HashSet<_>>()
-                        .into_iter()
-                        .collect();
-                    user_store.load_users(unique_user_ids).await?;
-
-                    let mut messages = Vec::with_capacity(response.messages.len());
-                    for message in response.messages {
-                        messages.push(ChannelMessage::from_proto(message, &user_store).await?);
-                    }
+                    let messages = messages_from_proto(response.messages, &user_store).await?;
+                    let loaded_all_messages = response.done;
 
                     channel.update(&mut cx, |channel, cx| {
-                        let old_count = channel.messages.summary().count;
-                        let new_count = messages.len();
-
-                        channel.messages = SumTree::new();
-                        channel.messages.extend(messages, &());
-                        cx.emit(ChannelEvent::Message {
-                            old_range: 0..old_count,
-                            new_count,
-                        });
+                        channel.insert_messages(messages, cx);
+                        channel.loaded_all_messages = loaded_all_messages;
                     });
 
                     Ok(())
@@ -232,6 +214,7 @@ impl Channel {
             rpc,
             messages: Default::default(),
             pending_messages: Default::default(),
+            loaded_all_messages: false,
             next_local_message_id: 0,
             _subscription,
         }
@@ -264,15 +247,18 @@ impl Channel {
                         .binary_search_by_key(&local_id, |msg| msg.local_id)
                     {
                         let body = this.pending_messages.remove(i).body;
-                        this.insert_message(
-                            ChannelMessage {
-                                id: response.message_id,
-                                timestamp: OffsetDateTime::from_unix_timestamp(
-                                    response.timestamp as i64,
-                                )?,
-                                body,
-                                sender,
-                            },
+                        this.insert_messages(
+                            SumTree::from_item(
+                                ChannelMessage {
+                                    id: response.message_id,
+                                    timestamp: OffsetDateTime::from_unix_timestamp(
+                                        response.timestamp as i64,
+                                    )?,
+                                    body,
+                                    sender,
+                                },
+                                &(),
+                            ),
                             cx,
                         );
                     }
@@ -286,6 +272,37 @@ impl Channel {
         Ok(())
     }
 
+    pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
+        if !self.loaded_all_messages {
+            let rpc = self.rpc.clone();
+            let user_store = self.user_store.clone();
+            let channel_id = self.details.id;
+            if let Some(before_message_id) = self.messages.first().map(|message| message.id) {
+                cx.spawn(|this, mut cx| {
+                    async move {
+                        let response = rpc
+                            .request(proto::GetChannelMessages {
+                                channel_id,
+                                before_message_id,
+                            })
+                            .await?;
+                        let loaded_all_messages = response.done;
+                        let messages = messages_from_proto(response.messages, &user_store).await?;
+                        this.update(&mut cx, |this, cx| {
+                            this.loaded_all_messages = loaded_all_messages;
+                            this.insert_messages(messages, cx);
+                        });
+                        Ok(())
+                    }
+                    .log_err()
+                })
+                .detach();
+                return true;
+            }
+        }
+        false
+    }
+
     pub fn message_count(&self) -> usize {
         self.messages.summary().count
     }
@@ -326,7 +343,9 @@ impl Channel {
         cx.spawn(|this, mut cx| {
             async move {
                 let message = ChannelMessage::from_proto(message, &user_store).await?;
-                this.update(&mut cx, |this, cx| this.insert_message(message, cx));
+                this.update(&mut cx, |this, cx| {
+                    this.insert_messages(SumTree::from_item(message, &()), cx)
+                });
                 Ok(())
             }
             .log_err()
@@ -335,27 +354,49 @@ impl Channel {
         Ok(())
     }
 
-    fn insert_message(&mut self, message: ChannelMessage, cx: &mut ModelContext<Self>) {
-        let mut old_cursor = self.messages.cursor::<u64, Count>();
-        let mut new_messages = old_cursor.slice(&message.id, Bias::Left, &());
-        let start_ix = old_cursor.sum_start().0;
-        let mut end_ix = start_ix;
-        if old_cursor.item().map_or(false, |m| m.id == message.id) {
-            old_cursor.next(&());
-            end_ix += 1;
+    fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
+        if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
+            let mut old_cursor = self.messages.cursor::<u64, Count>();
+            let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
+            let start_ix = old_cursor.sum_start().0;
+            let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
+            let removed_count = removed_messages.summary().count;
+            let new_count = messages.summary().count;
+            let end_ix = start_ix + removed_count;
+
+            new_messages.push_tree(messages, &());
+            new_messages.push_tree(old_cursor.suffix(&()), &());
+            drop(old_cursor);
+            self.messages = new_messages;
+
+            cx.emit(ChannelEvent::MessagesAdded {
+                old_range: start_ix..end_ix,
+                new_count,
+            });
+            cx.notify();
         }
+    }
+}
 
-        new_messages.push(message.clone(), &());
-        new_messages.push_tree(old_cursor.suffix(&()), &());
-        drop(old_cursor);
-        self.messages = new_messages;
-
-        cx.emit(ChannelEvent::Message {
-            old_range: start_ix..end_ix,
-            new_count: 1,
-        });
-        cx.notify();
+async fn messages_from_proto(
+    proto_messages: Vec<proto::ChannelMessage>,
+    user_store: &UserStore,
+) -> Result<SumTree<ChannelMessage>> {
+    let unique_user_ids = proto_messages
+        .iter()
+        .map(|m| m.sender_id)
+        .collect::<HashSet<_>>()
+        .into_iter()
+        .collect();
+    user_store.load_users(unique_user_ids).await?;
+
+    let mut messages = Vec::with_capacity(proto_messages.len());
+    for message in proto_messages {
+        messages.push(ChannelMessage::from_proto(message, &user_store).await?);
     }
+    let mut result = SumTree::new();
+    result.extend(messages, &());
+    Ok(result)
 }
 
 impl From<proto::Channel> for ChannelDetails {
@@ -489,9 +530,11 @@ mod tests {
                             sender_id: 6,
                         },
                     ],
+                    done: false,
                 },
             )
             .await;
+
         // Client requests all users for the received messages
         let mut get_users = server.receive::<proto::GetUsers>().await;
         get_users.payload.user_ids.sort();
@@ -518,7 +561,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::Message {
+            ChannelEvent::MessagesAdded {
                 old_range: 0..0,
                 new_count: 2,
             }
@@ -567,7 +610,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::Message {
+            ChannelEvent::MessagesAdded {
                 old_range: 2..2,
                 new_count: 1,
             }
@@ -580,7 +623,57 @@ mod tests {
                     .collect::<Vec<_>>(),
                 &[("as-cii".into(), "c".into())]
             )
-        })
+        });
+
+        // Scroll up to view older messages.
+        channel.update(&mut cx, |channel, cx| {
+            assert!(channel.load_more_messages(cx));
+        });
+        let get_messages = server.receive::<proto::GetChannelMessages>().await;
+        assert_eq!(get_messages.payload.channel_id, 5);
+        assert_eq!(get_messages.payload.before_message_id, 10);
+        server
+            .respond(
+                get_messages.receipt(),
+                proto::GetChannelMessagesResponse {
+                    done: true,
+                    messages: vec![
+                        proto::ChannelMessage {
+                            id: 8,
+                            body: "y".into(),
+                            timestamp: 998,
+                            sender_id: 5,
+                        },
+                        proto::ChannelMessage {
+                            id: 9,
+                            body: "z".into(),
+                            timestamp: 999,
+                            sender_id: 6,
+                        },
+                    ],
+                },
+            )
+            .await;
+
+        assert_eq!(
+            channel.next_event(&cx).await,
+            ChannelEvent::MessagesAdded {
+                old_range: 0..0,
+                new_count: 2,
+            }
+        );
+        channel.read_with(&cx, |channel, _| {
+            assert_eq!(
+                channel
+                    .messages_in_range(0..2)
+                    .map(|message| (message.sender.github_login.clone(), message.body.clone()))
+                    .collect::<Vec<_>>(),
+                &[
+                    ("nathansobo".into(), "y".into()),
+                    ("maxbrunsfeld".into(), "z".into())
+                ]
+            );
+        });
     }
 
     struct FakeServer {

zed/src/chat_panel.rs 🔗

@@ -22,9 +22,11 @@ pub struct ChatPanel {
 pub enum Event {}
 
 action!(Send);
+action!(LoadMoreMessages);
 
 pub fn init(cx: &mut MutableAppContext) {
     cx.add_action(ChatPanel::send);
+    cx.add_action(ChatPanel::load_more_messages);
 
     cx.add_bindings(vec![Binding::new("enter", Send, Some("ChatPanel"))]);
 }
@@ -78,6 +80,11 @@ impl ChatPanel {
             let subscription = cx.subscribe(&channel, Self::channel_did_change);
             self.message_list =
                 ListState::new(channel.read(cx).message_count(), Orientation::Bottom);
+            self.message_list.set_scroll_handler(|visible_range, cx| {
+                if visible_range.start < 5 {
+                    cx.dispatch_action(LoadMoreMessages);
+                }
+            });
             self.active_channel = Some((channel, subscription));
         }
     }
@@ -89,7 +96,7 @@ impl ChatPanel {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
-            ChannelEvent::Message {
+            ChannelEvent::MessagesAdded {
                 old_range,
                 new_count,
             } => {
@@ -191,6 +198,14 @@ impl ChatPanel {
                 .log_err();
         }
     }
+
+    fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) {
+        if let Some((channel, _)) = self.active_channel.as_ref() {
+            channel.update(cx, |channel, cx| {
+                channel.load_more_messages(cx);
+            })
+        }
+    }
 }
 
 impl Entity for ChatPanel {

zrpc/proto/zed.proto 🔗

@@ -32,6 +32,8 @@ message Envelope {
         SendChannelMessage send_channel_message = 27;
         SendChannelMessageResponse send_channel_message_response = 28;
         ChannelMessageSent channel_message_sent = 29;
+        GetChannelMessages get_channel_messages = 30;
+        GetChannelMessagesResponse get_channel_messages_response = 31;
     }
 }
 
@@ -130,6 +132,7 @@ message JoinChannel {
 
 message JoinChannelResponse {
     repeated ChannelMessage messages = 1;
+    bool done = 2;
 }
 
 message LeaveChannel {
@@ -159,6 +162,16 @@ message ChannelMessageSent {
     ChannelMessage message = 2;
 }
 
+message GetChannelMessages {
+    uint64 channel_id = 1;
+    uint64 before_message_id = 2;
+}
+
+message GetChannelMessagesResponse {
+    repeated ChannelMessage messages = 1;
+    bool done = 2;
+}
+
 // Entities
 
 message Peer {

zrpc/src/proto.rs 🔗

@@ -125,6 +125,8 @@ messages!(
     ChannelMessageSent,
     CloseBuffer,
     CloseWorktree,
+    GetChannelMessages,
+    GetChannelMessagesResponse,
     GetChannels,
     GetChannelsResponse,
     GetUsers,
@@ -158,6 +160,7 @@ request_messages!(
     (SaveBuffer, BufferSaved),
     (ShareWorktree, ShareWorktreeResponse),
     (SendChannelMessage, SendChannelMessageResponse),
+    (GetChannelMessages, GetChannelMessagesResponse),
 );
 
 entity_messages!(