Detailed changes
@@ -29,6 +29,7 @@ struct StateInner {
heights: SumTree<ElementHeight>,
scroll_position: f32,
orientation: Orientation,
+ scroll_handler: Option<Box<dyn FnMut(Range<usize>, &mut EventContext)>>,
}
#[derive(Clone, Debug)]
@@ -272,6 +273,7 @@ impl ListState {
heights,
scroll_position: 0.,
orientation,
+ scroll_handler: None,
})))
}
@@ -290,6 +292,13 @@ impl ListState {
drop(old_heights);
state.heights = new_heights;
}
+
+ pub fn set_scroll_handler(
+ &mut self,
+ handler: impl FnMut(Range<usize>, &mut EventContext) + 'static,
+ ) {
+ self.0.borrow_mut().scroll_handler = Some(Box::new(handler))
+ }
}
impl StateInner {
@@ -320,6 +329,11 @@ impl StateInner {
Orientation::Bottom => delta.y(),
};
self.scroll_position = (self.scroll_position + delta_y).max(0.).min(scroll_max);
+
+ if self.scroll_handler.is_some() {
+ let range = self.visible_range(height);
+ self.scroll_handler.as_mut().unwrap()(range, cx);
+ }
cx.notify();
true
@@ -85,7 +85,7 @@ async fn post_user(mut request: Request) -> tide::Result {
async fn put_user(mut request: Request) -> tide::Result {
request.require_admin().await?;
- let user_id = request.param("id")?.parse::<i32>()?;
+ let user_id = request.param("id")?.parse()?;
#[derive(Deserialize)]
struct Body {
@@ -104,14 +104,14 @@ async fn put_user(mut request: Request) -> tide::Result {
async fn delete_user(request: Request) -> tide::Result {
request.require_admin().await?;
- let user_id = db::UserId(request.param("id")?.parse::<i32>()?);
+ let user_id = db::UserId(request.param("id")?.parse()?);
request.db().delete_user(user_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
async fn delete_signup(request: Request) -> tide::Result {
request.require_admin().await?;
- let signup_id = db::SignupId(request.param("id")?.parse::<i32>()?);
+ let signup_id = db::SignupId(request.param("id")?.parse()?);
request.db().delete_signup(signup_id).await?;
Ok(tide::Redirect::new("/admin").into())
}
@@ -380,6 +380,7 @@ impl Db {
&self,
channel_id: ChannelId,
count: usize,
+ before_id: Option<MessageId>,
) -> Result<Vec<ChannelMessage>> {
test_support!(self, {
let query = r#"
@@ -389,14 +390,16 @@ impl Db {
FROM
channel_messages
WHERE
- channel_id = $1
+ channel_id = $1 AND
+ id < $2
ORDER BY id DESC
- LIMIT $2
+ LIMIT $3
) as recent_messages
ORDER BY id ASC
"#;
sqlx::query_as(query)
.bind(channel_id.0)
+ .bind(before_id.unwrap_or(MessageId::MAX))
.bind(count as i64)
.fetch_all(&self.pool)
.await
@@ -412,6 +415,9 @@ macro_rules! id_type {
pub struct $name(pub i32);
impl $name {
+ #[allow(unused)]
+ pub const MAX: Self = Self(i32::MAX);
+
#[allow(unused)]
pub fn from_proto(value: u64) -> Self {
Self(value as i32)
@@ -512,10 +518,22 @@ pub mod tests {
.unwrap();
}
- let messages = db.get_recent_channel_messages(channel, 5).await.unwrap();
+ let messages = db
+ .get_recent_channel_messages(channel, 5, None)
+ .await
+ .unwrap();
assert_eq!(
messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
["5", "6", "7", "8", "9"]
);
+
+ let prev_messages = db
+ .get_recent_channel_messages(channel, 4, Some(messages[0].id))
+ .await
+ .unwrap();
+ assert_eq!(
+ prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+ ["1", "2", "3", "4"]
+ );
}
}
@@ -1,6 +1,6 @@
use super::{
auth,
- db::{ChannelId, UserId},
+ db::{ChannelId, MessageId, UserId},
AppState,
};
use anyhow::anyhow;
@@ -77,6 +77,8 @@ struct Channel {
connection_ids: HashSet<ConnectionId>,
}
+const MESSAGE_COUNT_PER_PAGE: usize = 50;
+
impl Server {
pub fn new(
app_state: Arc<AppState>,
@@ -105,7 +107,8 @@ impl Server {
.add_handler(Server::get_users)
.add_handler(Server::join_channel)
.add_handler(Server::leave_channel)
- .add_handler(Server::send_channel_message);
+ .add_handler(Server::send_channel_message)
+ .add_handler(Server::get_channel_messages);
Arc::new(server)
}
@@ -592,7 +595,7 @@ impl Server {
let messages = self
.app_state
.db
- .get_recent_channel_messages(channel_id, 50)
+ .get_recent_channel_messages(channel_id, MESSAGE_COUNT_PER_PAGE, None)
.await?
.into_iter()
.map(|msg| proto::ChannelMessage {
@@ -601,9 +604,15 @@ impl Server {
timestamp: msg.sent_at.unix_timestamp() as u64,
sender_id: msg.sender_id.to_proto(),
})
- .collect();
+ .collect::<Vec<_>>();
self.peer
- .respond(request.receipt(), proto::JoinChannelResponse { messages })
+ .respond(
+ request.receipt(),
+ proto::JoinChannelResponse {
+ done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+ messages,
+ },
+ )
.await?;
Ok(())
}
@@ -685,6 +694,54 @@ impl Server {
Ok(())
}
+ async fn get_channel_messages(
+ self: Arc<Self>,
+ request: TypedEnvelope<proto::GetChannelMessages>,
+ ) -> tide::Result<()> {
+ let user_id = self
+ .state
+ .read()
+ .await
+ .user_id_for_connection(request.sender_id)?;
+ let channel_id = ChannelId::from_proto(request.payload.channel_id);
+ if !self
+ .app_state
+ .db
+ .can_user_access_channel(user_id, channel_id)
+ .await?
+ {
+ Err(anyhow!("access denied"))?;
+ }
+
+ let messages = self
+ .app_state
+ .db
+ .get_recent_channel_messages(
+ channel_id,
+ MESSAGE_COUNT_PER_PAGE,
+ Some(MessageId::from_proto(request.payload.before_message_id)),
+ )
+ .await?
+ .into_iter()
+ .map(|msg| proto::ChannelMessage {
+ id: msg.id.to_proto(),
+ body: msg.body,
+ timestamp: msg.sent_at.unix_timestamp() as u64,
+ sender_id: msg.sender_id.to_proto(),
+ })
+ .collect::<Vec<_>>();
+ self.peer
+ .respond(
+ request.receipt(),
+ proto::GetChannelMessagesResponse {
+ done: messages.len() < MESSAGE_COUNT_PER_PAGE,
+ messages,
+ },
+ )
+ .await?;
+ Ok(())
+ }
+
async fn broadcast_in_worktree<T: proto::EnvelopedMessage>(
&self,
worktree_id: u64,
@@ -37,6 +37,7 @@ pub struct ChannelDetails {
pub struct Channel {
details: ChannelDetails,
messages: SumTree<ChannelMessage>,
+ loaded_all_messages: bool,
pending_messages: Vec<PendingChannelMessage>,
next_local_message_id: u64,
user_store: Arc<UserStore>,
@@ -70,7 +71,7 @@ pub enum ChannelListEvent {}
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
- Message {
+ MessagesAdded {
old_range: Range<usize>,
new_count: usize,
},
@@ -192,31 +193,12 @@ impl Channel {
cx.spawn(|channel, mut cx| {
async move {
let response = rpc.request(proto::JoinChannel { channel_id }).await?;
-
- let unique_user_ids = response
- .messages
- .iter()
- .map(|m| m.sender_id)
- .collect::<HashSet<_>>()
- .into_iter()
- .collect();
- user_store.load_users(unique_user_ids).await?;
-
- let mut messages = Vec::with_capacity(response.messages.len());
- for message in response.messages {
- messages.push(ChannelMessage::from_proto(message, &user_store).await?);
- }
+ let messages = messages_from_proto(response.messages, &user_store).await?;
+ let loaded_all_messages = response.done;
channel.update(&mut cx, |channel, cx| {
- let old_count = channel.messages.summary().count;
- let new_count = messages.len();
-
- channel.messages = SumTree::new();
- channel.messages.extend(messages, &());
- cx.emit(ChannelEvent::Message {
- old_range: 0..old_count,
- new_count,
- });
+ channel.insert_messages(messages, cx);
+ channel.loaded_all_messages = loaded_all_messages;
});
Ok(())
@@ -232,6 +214,7 @@ impl Channel {
rpc,
messages: Default::default(),
pending_messages: Default::default(),
+ loaded_all_messages: false,
next_local_message_id: 0,
_subscription,
}
@@ -264,15 +247,18 @@ impl Channel {
.binary_search_by_key(&local_id, |msg| msg.local_id)
{
let body = this.pending_messages.remove(i).body;
- this.insert_message(
- ChannelMessage {
- id: response.message_id,
- timestamp: OffsetDateTime::from_unix_timestamp(
- response.timestamp as i64,
- )?,
- body,
- sender,
- },
+ this.insert_messages(
+ SumTree::from_item(
+ ChannelMessage {
+ id: response.message_id,
+ timestamp: OffsetDateTime::from_unix_timestamp(
+ response.timestamp as i64,
+ )?,
+ body,
+ sender,
+ },
+ &(),
+ ),
cx,
);
}
@@ -286,6 +272,37 @@ impl Channel {
Ok(())
}
+ pub fn load_more_messages(&mut self, cx: &mut ModelContext<Self>) -> bool {
+ if !self.loaded_all_messages {
+ let rpc = self.rpc.clone();
+ let user_store = self.user_store.clone();
+ let channel_id = self.details.id;
+ if let Some(before_message_id) = self.messages.first().map(|message| message.id) {
+ cx.spawn(|this, mut cx| {
+ async move {
+ let response = rpc
+ .request(proto::GetChannelMessages {
+ channel_id,
+ before_message_id,
+ })
+ .await?;
+ let loaded_all_messages = response.done;
+ let messages = messages_from_proto(response.messages, &user_store).await?;
+ this.update(&mut cx, |this, cx| {
+ this.loaded_all_messages = loaded_all_messages;
+ this.insert_messages(messages, cx);
+ });
+ Ok(())
+ }
+ .log_err()
+ })
+ .detach();
+ return true;
+ }
+ }
+ false
+ }
+
pub fn message_count(&self) -> usize {
self.messages.summary().count
}
@@ -326,7 +343,9 @@ impl Channel {
cx.spawn(|this, mut cx| {
async move {
let message = ChannelMessage::from_proto(message, &user_store).await?;
- this.update(&mut cx, |this, cx| this.insert_message(message, cx));
+ this.update(&mut cx, |this, cx| {
+ this.insert_messages(SumTree::from_item(message, &()), cx)
+ });
Ok(())
}
.log_err()
@@ -335,27 +354,49 @@ impl Channel {
Ok(())
}
- fn insert_message(&mut self, message: ChannelMessage, cx: &mut ModelContext<Self>) {
- let mut old_cursor = self.messages.cursor::<u64, Count>();
- let mut new_messages = old_cursor.slice(&message.id, Bias::Left, &());
- let start_ix = old_cursor.sum_start().0;
- let mut end_ix = start_ix;
- if old_cursor.item().map_or(false, |m| m.id == message.id) {
- old_cursor.next(&());
- end_ix += 1;
+ fn insert_messages(&mut self, messages: SumTree<ChannelMessage>, cx: &mut ModelContext<Self>) {
+ if let Some((first_message, last_message)) = messages.first().zip(messages.last()) {
+ let mut old_cursor = self.messages.cursor::<u64, Count>();
+ let mut new_messages = old_cursor.slice(&first_message.id, Bias::Left, &());
+ let start_ix = old_cursor.sum_start().0;
+ let removed_messages = old_cursor.slice(&last_message.id, Bias::Right, &());
+ let removed_count = removed_messages.summary().count;
+ let new_count = messages.summary().count;
+ let end_ix = start_ix + removed_count;
+
+ new_messages.push_tree(messages, &());
+ new_messages.push_tree(old_cursor.suffix(&()), &());
+ drop(old_cursor);
+ self.messages = new_messages;
+
+ cx.emit(ChannelEvent::MessagesAdded {
+ old_range: start_ix..end_ix,
+ new_count,
+ });
+ cx.notify();
}
+ }
+}
- new_messages.push(message.clone(), &());
- new_messages.push_tree(old_cursor.suffix(&()), &());
- drop(old_cursor);
- self.messages = new_messages;
-
- cx.emit(ChannelEvent::Message {
- old_range: start_ix..end_ix,
- new_count: 1,
- });
- cx.notify();
+async fn messages_from_proto(
+ proto_messages: Vec<proto::ChannelMessage>,
+ user_store: &UserStore,
+) -> Result<SumTree<ChannelMessage>> {
+ let unique_user_ids = proto_messages
+ .iter()
+ .map(|m| m.sender_id)
+ .collect::<HashSet<_>>()
+ .into_iter()
+ .collect();
+ user_store.load_users(unique_user_ids).await?;
+
+ let mut messages = Vec::with_capacity(proto_messages.len());
+ for message in proto_messages {
+ messages.push(ChannelMessage::from_proto(message, &user_store).await?);
}
+ let mut result = SumTree::new();
+ result.extend(messages, &());
+ Ok(result)
}
impl From<proto::Channel> for ChannelDetails {
@@ -489,9 +530,11 @@ mod tests {
sender_id: 6,
},
],
+ done: false,
},
)
.await;
+
// Client requests all users for the received messages
let mut get_users = server.receive::<proto::GetUsers>().await;
get_users.payload.user_ids.sort();
@@ -518,7 +561,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
- ChannelEvent::Message {
+ ChannelEvent::MessagesAdded {
old_range: 0..0,
new_count: 2,
}
@@ -567,7 +610,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
- ChannelEvent::Message {
+ ChannelEvent::MessagesAdded {
old_range: 2..2,
new_count: 1,
}
@@ -580,7 +623,57 @@ mod tests {
.collect::<Vec<_>>(),
&[("as-cii".into(), "c".into())]
)
- })
+ });
+
+ // Scroll up to view older messages.
+ channel.update(&mut cx, |channel, cx| {
+ assert!(channel.load_more_messages(cx));
+ });
+ let get_messages = server.receive::<proto::GetChannelMessages>().await;
+ assert_eq!(get_messages.payload.channel_id, 5);
+ assert_eq!(get_messages.payload.before_message_id, 10);
+ server
+ .respond(
+ get_messages.receipt(),
+ proto::GetChannelMessagesResponse {
+ done: true,
+ messages: vec![
+ proto::ChannelMessage {
+ id: 8,
+ body: "y".into(),
+ timestamp: 998,
+ sender_id: 5,
+ },
+ proto::ChannelMessage {
+ id: 9,
+ body: "z".into(),
+ timestamp: 999,
+ sender_id: 6,
+ },
+ ],
+ },
+ )
+ .await;
+
+ assert_eq!(
+ channel.next_event(&cx).await,
+ ChannelEvent::MessagesAdded {
+ old_range: 0..0,
+ new_count: 2,
+ }
+ );
+ channel.read_with(&cx, |channel, _| {
+ assert_eq!(
+ channel
+ .messages_in_range(0..2)
+ .map(|message| (message.sender.github_login.clone(), message.body.clone()))
+ .collect::<Vec<_>>(),
+ &[
+ ("nathansobo".into(), "y".into()),
+ ("maxbrunsfeld".into(), "z".into())
+ ]
+ );
+ });
}
struct FakeServer {
@@ -22,9 +22,11 @@ pub struct ChatPanel {
pub enum Event {}
action!(Send);
+action!(LoadMoreMessages);
pub fn init(cx: &mut MutableAppContext) {
cx.add_action(ChatPanel::send);
+ cx.add_action(ChatPanel::load_more_messages);
cx.add_bindings(vec![Binding::new("enter", Send, Some("ChatPanel"))]);
}
@@ -78,6 +80,11 @@ impl ChatPanel {
let subscription = cx.subscribe(&channel, Self::channel_did_change);
self.message_list =
ListState::new(channel.read(cx).message_count(), Orientation::Bottom);
+ self.message_list.set_scroll_handler(|visible_range, cx| {
+ if visible_range.start < 5 {
+ cx.dispatch_action(LoadMoreMessages);
+ }
+ });
self.active_channel = Some((channel, subscription));
}
}
@@ -89,7 +96,7 @@ impl ChatPanel {
cx: &mut ViewContext<Self>,
) {
match event {
- ChannelEvent::Message {
+ ChannelEvent::MessagesAdded {
old_range,
new_count,
} => {
@@ -191,6 +198,14 @@ impl ChatPanel {
.log_err();
}
}
+
+ fn load_more_messages(&mut self, _: &LoadMoreMessages, cx: &mut ViewContext<Self>) {
+ if let Some((channel, _)) = self.active_channel.as_ref() {
+ channel.update(cx, |channel, cx| {
+ channel.load_more_messages(cx);
+ })
+ }
+ }
}
impl Entity for ChatPanel {
@@ -32,6 +32,8 @@ message Envelope {
SendChannelMessage send_channel_message = 27;
SendChannelMessageResponse send_channel_message_response = 28;
ChannelMessageSent channel_message_sent = 29;
+ GetChannelMessages get_channel_messages = 30;
+ GetChannelMessagesResponse get_channel_messages_response = 31;
}
}
@@ -130,6 +132,7 @@ message JoinChannel {
message JoinChannelResponse {
repeated ChannelMessage messages = 1;
+ bool done = 2;
}
message LeaveChannel {
@@ -159,6 +162,16 @@ message ChannelMessageSent {
ChannelMessage message = 2;
}
+message GetChannelMessages {
+ uint64 channel_id = 1;
+ uint64 before_message_id = 2;
+}
+
+message GetChannelMessagesResponse {
+ repeated ChannelMessage messages = 1;
+ bool done = 2;
+}
+
// Entities
message Peer {
@@ -125,6 +125,8 @@ messages!(
ChannelMessageSent,
CloseBuffer,
CloseWorktree,
+ GetChannelMessages,
+ GetChannelMessagesResponse,
GetChannels,
GetChannelsResponse,
GetUsers,
@@ -158,6 +160,7 @@ request_messages!(
(SaveBuffer, BufferSaved),
(ShareWorktree, ShareWorktreeResponse),
(SendChannelMessage, SendChannelMessageResponse),
+ (GetChannelMessages, GetChannelMessagesResponse),
);
entity_messages!(