Add initial unit test for channel list

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

gpui/src/app.rs       |  47 +++++++++++
zed/Cargo.toml        |   1 
zed/src/channel.rs    | 184 ++++++++++++++++++++++++++++++++++++++++++++
zed/src/chat_panel.rs |  16 ++
zrpc/src/proto.rs     |   5 -
5 files changed, 240 insertions(+), 13 deletions(-)

Detailed changes

gpui/src/app.rs 🔗

@@ -2356,6 +2356,53 @@ impl<T: Entity> ModelHandle<T> {
         cx.update_model(self, update)
     }
 
+    pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
+        let (mut tx, mut rx) = mpsc::channel(1);
+        let mut cx = cx.cx.borrow_mut();
+        let subscription = cx.observe(self, move |_, _| {
+            tx.blocking_send(()).ok();
+        });
+
+        let duration = if std::env::var("CI").is_ok() {
+            Duration::from_secs(5)
+        } else {
+            Duration::from_secs(1)
+        };
+
+        async move {
+            let notification = timeout(duration, rx.recv())
+                .await
+                .expect("next notification timed out");
+            drop(subscription);
+            notification.expect("model dropped while test was waiting for its next notification")
+        }
+    }
+
+    pub fn next_event(&self, cx: &TestAppContext) -> impl Future<Output = T::Event>
+    where
+        T::Event: Clone,
+    {
+        let (mut tx, mut rx) = mpsc::channel(1);
+        let mut cx = cx.cx.borrow_mut();
+        let subscription = cx.subscribe(self, move |_, event, _| {
+            tx.blocking_send(event.clone()).ok();
+        });
+
+        let duration = if std::env::var("CI").is_ok() {
+            Duration::from_secs(5)
+        } else {
+            Duration::from_secs(1)
+        };
+
+        async move {
+            let event = timeout(duration, rx.recv())
+                .await
+                .expect("next event timed out");
+            drop(subscription);
+            event.expect("model dropped while test was waiting for its next event")
+        }
+    }
+
     pub fn condition(
         &self,
         cx: &TestAppContext,

zed/Cargo.toml 🔗

@@ -61,6 +61,7 @@ env_logger = "0.8"
 serde_json = { version = "1.0.64", features = ["preserve_order"] }
 tempdir = { version = "0.3.7" }
 unindent = "0.1.7"
+zrpc = { path = "../zrpc", features = ["test-support"] }
 
 [package.metadata.bundle]
 icon = ["app-icon@2x.png", "app-icon.png"]

zed/src/channel.rs 🔗

@@ -40,7 +40,7 @@ pub struct Channel {
     _subscription: rpc::Subscription,
 }
 
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct ChannelMessage {
     pub id: u64,
     pub sender_id: u64,
@@ -63,10 +63,11 @@ struct Count(usize);
 
 pub enum ChannelListEvent {}
 
+#[derive(Clone, Debug, PartialEq)]
 pub enum ChannelEvent {
     Message {
         old_range: Range<usize>,
-        message: ChannelMessage,
+        new_count: usize,
     },
 }
 
@@ -170,11 +171,16 @@ impl Channel {
             cx.spawn(|channel, mut cx| async move {
                 match rpc.request(proto::JoinChannel { channel_id }).await {
                     Ok(response) => channel.update(&mut cx, |channel, cx| {
+                        let old_count = channel.messages.summary().count.0;
+                        let new_count = response.messages.len();
                         channel.messages = SumTree::new();
                         channel
                             .messages
                             .extend(response.messages.into_iter().map(Into::into), &());
-                        cx.notify();
+                        cx.emit(ChannelEvent::Message {
+                            old_range: 0..old_count,
+                            new_count,
+                        });
                     }),
                     Err(error) => log::error!("error joining channel: {}", error),
                 }
@@ -235,6 +241,12 @@ impl Channel {
         &self.messages
     }
 
+    pub fn messages_in_range(&self, range: Range<usize>) -> impl Iterator<Item = &ChannelMessage> {
+        let mut cursor = self.messages.cursor::<Count, ()>();
+        cursor.seek(&Count(range.start), Bias::Left, &());
+        cursor.take(range.len())
+    }
+
     pub fn pending_messages(&self) -> &[PendingChannelMessage] {
         &self.pending_messages
     }
@@ -277,7 +289,7 @@ impl Channel {
 
         cx.emit(ChannelEvent::Message {
             old_range: start_ix..end_ix,
-            message,
+            new_count: 1,
         });
         cx.notify();
     }
@@ -334,3 +346,167 @@ impl<'a> sum_tree::Dimension<'a, ChannelMessageSummary> for Count {
         self.0 += summary.count.0;
     }
 }
+
+impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
+    fn cmp(&self, other: &Self, _: &()) -> std::cmp::Ordering {
+        Ord::cmp(&self.0, &other.0)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use gpui::TestAppContext;
+    use postage::mpsc::Receiver;
+    use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
+
+    #[gpui::test]
+    async fn test_channel_messages(mut cx: TestAppContext) {
+        let user_id = 5;
+        let client = Client::new();
+        let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+
+        let channel_list = cx.add_model(|cx| ChannelList::new(client.clone(), cx));
+        channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
+
+        // Get the available channels.
+        let message = server.receive::<proto::GetChannels>().await;
+        server
+            .respond(
+                message.receipt(),
+                proto::GetChannelsResponse {
+                    channels: vec![proto::Channel {
+                        id: 5,
+                        name: "the-channel".to_string(),
+                    }],
+                },
+            )
+            .await;
+        channel_list.next_notification(&cx).await;
+        channel_list.read_with(&cx, |list, _| {
+            assert_eq!(
+                list.available_channels().unwrap(),
+                &[ChannelDetails {
+                    id: 5,
+                    name: "the-channel".into(),
+                }]
+            )
+        });
+
+        // Join a channel and populate its existing messages.
+        let channel = channel_list
+            .update(&mut cx, |list, cx| {
+                let channel_id = list.available_channels().unwrap()[0].id;
+                list.get_channel(channel_id, cx)
+            })
+            .unwrap();
+        channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
+        let message = server.receive::<proto::JoinChannel>().await;
+        server
+            .respond(
+                message.receipt(),
+                proto::JoinChannelResponse {
+                    messages: vec![
+                        proto::ChannelMessage {
+                            id: 10,
+                            body: "a".into(),
+                            timestamp: 1000,
+                            sender_id: 5,
+                        },
+                        proto::ChannelMessage {
+                            id: 11,
+                            body: "b".into(),
+                            timestamp: 1001,
+                            sender_id: 5,
+                        },
+                    ],
+                },
+            )
+            .await;
+        assert_eq!(
+            channel.next_event(&cx).await,
+            ChannelEvent::Message {
+                old_range: 0..0,
+                new_count: 2,
+            }
+        );
+        channel.read_with(&cx, |channel, _| {
+            assert_eq!(
+                channel
+                    .messages_in_range(0..2)
+                    .map(|message| &message.body)
+                    .collect::<Vec<_>>(),
+                &["a", "b"]
+            );
+        });
+
+        // Receive a new message.
+        server
+            .send(proto::ChannelMessageSent {
+                channel_id: channel.read_with(&cx, |channel, _| channel.details.id),
+                message: Some(proto::ChannelMessage {
+                    id: 12,
+                    body: "c".into(),
+                    timestamp: 1002,
+                    sender_id: 5,
+                }),
+            })
+            .await;
+        assert_eq!(
+            channel.next_event(&cx).await,
+            ChannelEvent::Message {
+                old_range: 2..2,
+                new_count: 1,
+            }
+        );
+    }
+
+    struct FakeServer {
+        peer: Arc<Peer>,
+        incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
+        connection_id: ConnectionId,
+    }
+
+    impl FakeServer {
+        async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
+            let (client_conn, server_conn) = Channel::bidirectional();
+            let peer = Peer::new();
+            let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+            cx.background().spawn(io).detach();
+
+            client
+                .add_connection(user_id, client_conn, cx.to_async())
+                .await
+                .unwrap();
+
+            Self {
+                peer,
+                incoming,
+                connection_id,
+            }
+        }
+
+        async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
+            self.peer.send(self.connection_id, message).await.unwrap();
+        }
+
+        async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
+            *self
+                .incoming
+                .recv()
+                .await
+                .unwrap()
+                .into_any()
+                .downcast::<TypedEnvelope<M>>()
+                .unwrap()
+        }
+
+        async fn respond<T: proto::RequestMessage>(
+            &self,
+            receipt: Receipt<T>,
+            response: T::Response,
+        ) {
+            self.peer.respond(receipt, response).await.unwrap()
+        }
+    }
+}

zed/src/chat_panel.rs 🔗

@@ -89,14 +89,22 @@ impl ChatPanel {
 
     fn channel_did_change(
         &mut self,
-        _: ModelHandle<Channel>,
+        channel: ModelHandle<Channel>,
         event: &ChannelEvent,
         cx: &mut ViewContext<Self>,
     ) {
         match event {
-            ChannelEvent::Message { old_range, message } => {
-                self.messages
-                    .splice(old_range.clone(), Some(self.render_message(message)));
+            ChannelEvent::Message {
+                old_range,
+                new_count,
+            } => {
+                self.messages.splice(
+                    old_range.clone(),
+                    channel
+                        .read(cx)
+                        .messages_in_range(old_range.start..(old_range.start + new_count))
+                        .map(|message| self.render_message(message)),
+                );
             }
         }
         cx.notify();

zrpc/src/proto.rs 🔗

@@ -19,7 +19,6 @@ pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
         responding_to: Option<u32>,
         original_sender_id: Option<u32>,
     ) -> Envelope;
-    fn matches_envelope(envelope: &Envelope) -> bool;
     fn from_envelope(envelope: Envelope) -> Option<Self>;
 }
 
@@ -90,10 +89,6 @@ macro_rules! messages {
                     }
                 }
 
-                fn matches_envelope(envelope: &Envelope) -> bool {
-                    matches!(&envelope.payload, Some(envelope::Payload::$name(_)))
-                }
-
                 fn from_envelope(envelope: Envelope) -> Option<Self> {
                     if let Some(envelope::Payload::$name(msg)) = envelope.payload {
                         Some(msg)