Fetch messages when joining a channel

Nathan Sobo created

Change summary

gpui/src/app.rs      | 20 ++++++++++
server/src/tests.rs  | 12 ++++--
zed/src/channel.rs   | 81 +++++++++++++++++++++++++++++++++++++++++----
zrpc/proto/zed.proto |  9 +++-
zrpc/src/proto.rs    |  1 
5 files changed, 108 insertions(+), 15 deletions(-)

Detailed changes

gpui/src/app.rs 🔗

@@ -14,7 +14,7 @@ use keymap::MatchResult;
 use parking_lot::{Mutex, RwLock};
 use pathfinder_geometry::{rect::RectF, vector::vec2f};
 use platform::Event;
-use postage::{mpsc, sink::Sink as _, stream::Stream as _};
+use postage::{mpsc, oneshot, sink::Sink as _, stream::Stream as _};
 use smol::prelude::*;
 use std::{
     any::{type_name, Any, TypeId},
@@ -2310,6 +2310,24 @@ impl<T: Entity> ModelHandle<T> {
         cx.update_model(self, update)
     }
 
+    pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
+        let (tx, mut rx) = oneshot::channel();
+        let mut tx = Some(tx);
+
+        let mut cx = cx.cx.borrow_mut();
+        self.update(&mut *cx, |_, cx| {
+            cx.observe(self, move |_, _, _| {
+                if let Some(mut tx) = tx.take() {
+                    tx.blocking_send(()).ok();
+                }
+            });
+        });
+
+        async move {
+            rx.recv().await;
+        }
+    }
+
     pub fn condition(
         &self,
         cx: &TestAppContext,

server/src/tests.rs 🔗

@@ -480,8 +480,6 @@ async fn test_peer_disconnection(mut cx_a: TestAppContext, cx_b: TestAppContext)
 
 #[gpui::test]
 async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
-    let lang_registry = Arc::new(LanguageRegistry::new());
-
     // Connect to a server as 2 clients.
     let mut server = TestServer::start().await;
     let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
@@ -531,8 +529,14 @@ async fn test_basic_chat(mut cx_a: TestAppContext, cx_b: TestAppContext) {
         )
     });
 
-    let channel_a = channels_a.read_with(&cx_a, |this, cx| {
-        this.get_channel(channel_id.to_proto(), &cx).unwrap()
+    let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+        this.get_channel(channel_id.to_proto(), cx).unwrap()
+    });
+
+    channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_none()));
+    channel_a.next_notification(&cx_a).await;
+    channel_a.read_with(&cx_a, |channel, _| {
+        assert_eq!(channel.messages().unwrap().len(), 1);
     });
 }
 

zed/src/channel.rs 🔗

@@ -1,8 +1,11 @@
 use crate::rpc::{self, Client};
 use anyhow::{Context, Result};
-use gpui::{AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, WeakModelHandle};
+use gpui::{
+    executor, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
+    WeakModelHandle,
+};
 use std::{
-    collections::{HashMap, VecDeque},
+    collections::{hash_map, HashMap, VecDeque},
     sync::Arc,
 };
 use zrpc::{
@@ -16,7 +19,7 @@ pub struct ChannelList {
     rpc: Arc<Client>,
 }
 
-#[derive(Debug, PartialEq)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct ChannelDetails {
     pub id: u64,
     pub name: String,
@@ -28,6 +31,7 @@ pub struct Channel {
     messages: Option<VecDeque<ChannelMessage>>,
     rpc: Arc<Client>,
     _subscription: rpc::Subscription,
+    background: Arc<executor::Background>,
 }
 
 pub struct ChannelMessage {
@@ -57,11 +61,28 @@ impl ChannelList {
         &self.available_channels
     }
 
-    pub fn get_channel(&self, id: u64, cx: &AppContext) -> Option<ModelHandle<Channel>> {
-        self.channels
-            .get(&id)
-            .cloned()
-            .and_then(|handle| handle.upgrade(cx))
+    pub fn get_channel(
+        &mut self,
+        id: u64,
+        cx: &mut MutableAppContext,
+    ) -> Option<ModelHandle<Channel>> {
+        match self.channels.entry(id) {
+            hash_map::Entry::Occupied(entry) => entry.get().upgrade(cx),
+            hash_map::Entry::Vacant(entry) => {
+                if let Some(details) = self
+                    .available_channels
+                    .iter()
+                    .find(|details| details.id == id)
+                {
+                    let rpc = self.rpc.clone();
+                    let channel = cx.add_model(|cx| Channel::new(details.clone(), rpc, cx));
+                    entry.insert(channel.downgrade());
+                    Some(channel)
+                } else {
+                    None
+                }
+            }
+        }
     }
 }
 
@@ -73,12 +94,31 @@ impl Channel {
     pub fn new(details: ChannelDetails, rpc: Arc<Client>, cx: &mut ModelContext<Self>) -> Self {
         let _subscription = rpc.subscribe_from_model(details.id, cx, Self::handle_message_sent);
 
+        {
+            let rpc = rpc.clone();
+            let channel_id = details.id;
+            cx.spawn(|channel, mut cx| async move {
+                match rpc.request(proto::JoinChannel { channel_id }).await {
+                    Ok(response) => {
+                        let messages = response.messages.into_iter().map(Into::into).collect();
+                        channel.update(&mut cx, |channel, cx| {
+                            channel.messages = Some(messages);
+                            cx.notify();
+                        })
+                    }
+                    Err(error) => log::error!("error joining channel: {}", error),
+                }
+            })
+            .detach();
+        }
+
         Self {
             details,
             rpc,
             first_message_id: None,
             messages: None,
             _subscription,
+            background: cx.background().clone(),
         }
     }
 
@@ -90,6 +130,25 @@ impl Channel {
     ) -> Result<()> {
         Ok(())
     }
+
+    pub fn messages(&self) -> Option<&VecDeque<ChannelMessage>> {
+        self.messages.as_ref()
+    }
+}
+
+// TODO: Implement the server side of leaving a channel
+impl Drop for Channel {
+    fn drop(&mut self) {
+        let rpc = self.rpc.clone();
+        let channel_id = self.details.id;
+        self.background
+            .spawn(async move {
+                if let Err(error) = rpc.send(proto::LeaveChannel { channel_id }).await {
+                    log::error!("error leaving channel: {}", error);
+                };
+            })
+            .detach()
+    }
 }
 
 impl From<proto::Channel> for ChannelDetails {
@@ -100,3 +159,9 @@ impl From<proto::Channel> for ChannelDetails {
         }
     }
 }
+
+impl From<proto::ChannelMessage> for ChannelMessage {
+    fn from(message: proto::ChannelMessage) -> Self {
+        ChannelMessage { id: message.id }
+    }
+}

zrpc/proto/zed.proto 🔗

@@ -30,8 +30,9 @@ message Envelope {
         GetUsersResponse get_users_response = 25;
         JoinChannel join_channel = 26;
         JoinChannelResponse join_channel_response = 27;
-        SendChannelMessage send_channel_message = 28;
-        ChannelMessageSent channel_message_sent = 29;
+        LeaveChannel leave_channel = 28;
+        SendChannelMessage send_channel_message = 29;
+        ChannelMessageSent channel_message_sent = 30;
     }
 }
 
@@ -141,6 +142,10 @@ message JoinChannelResponse {
     repeated ChannelMessage messages = 1;
 }
 
+message LeaveChannel {
+    uint64 channel_id = 1;
+}
+
 message GetUsers {
     repeated uint64 user_ids = 1;
 }

zrpc/src/proto.rs 🔗

@@ -138,6 +138,7 @@ messages!(
     GetUsersResponse,
     JoinChannel,
     JoinChannelResponse,
+    LeaveChannel,
     OpenBuffer,
     OpenBufferResponse,
     OpenWorktree,