Re-join channel when reconnecting

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

gpui/src/executor.rs  |  14 +++---
zed/src/channel.rs    | 100 ++++++++++++++++++++++++++++++++++----------
zed/src/chat_panel.rs |   2 
zed/src/rpc.rs        |   9 ++-
zrpc/src/conn.rs      |   6 +-
5 files changed, 93 insertions(+), 38 deletions(-)

Detailed changes

gpui/src/executor.rs 🔗

@@ -51,7 +51,7 @@ struct DeterministicState {
     forbid_parking: bool,
     block_on_ticks: RangeInclusive<usize>,
     now: Instant,
-    pending_sleeps: Vec<(Instant, barrier::Sender)>,
+    pending_timers: Vec<(Instant, barrier::Sender)>,
 }
 
 pub struct Deterministic {
@@ -71,7 +71,7 @@ impl Deterministic {
                 forbid_parking: false,
                 block_on_ticks: 0..=1000,
                 now: Instant::now(),
-                pending_sleeps: Default::default(),
+                pending_timers: Default::default(),
             })),
             parker: Default::default(),
         }
@@ -428,14 +428,14 @@ impl Foreground {
         }
     }
 
-    pub async fn sleep(&self, duration: Duration) {
+    pub async fn timer(&self, duration: Duration) {
         match self {
             Self::Deterministic(executor) => {
                 let (tx, mut rx) = barrier::channel();
                 {
                     let mut state = executor.state.lock();
                     let wakeup_at = state.now + duration;
-                    state.pending_sleeps.push((wakeup_at, tx));
+                    state.pending_timers.push((wakeup_at, tx));
                 }
                 rx.recv().await;
             }
@@ -453,11 +453,11 @@ impl Foreground {
                 let mut state = executor.state.lock();
                 state.now += duration;
                 let now = state.now;
-                let mut pending_sleeps = mem::take(&mut state.pending_sleeps);
+                let mut pending_timers = mem::take(&mut state.pending_timers);
                 drop(state);
 
-                pending_sleeps.retain(|(wakeup, _)| *wakeup > now);
-                executor.state.lock().pending_sleeps.extend(pending_sleeps);
+                pending_timers.retain(|(wakeup, _)| *wakeup > now);
+                executor.state.lock().pending_timers.extend(pending_timers);
             }
             _ => panic!("this method can only be called on a deterministic executor"),
         }

zed/src/channel.rs 🔗

@@ -11,6 +11,7 @@ use gpui::{
 use postage::prelude::Stream;
 use std::{
     collections::{HashMap, HashSet},
+    mem,
     ops::Range,
     sync::Arc,
 };
@@ -71,7 +72,7 @@ pub enum ChannelListEvent {}
 
 #[derive(Clone, Debug, PartialEq)]
 pub enum ChannelEvent {
-    MessagesAdded {
+    MessagesUpdated {
         old_range: Range<usize>,
         new_count: usize,
     },
@@ -93,26 +94,40 @@ impl ChannelList {
                 let mut status = rpc.status();
                 loop {
                     let status = status.recv().await.unwrap();
-                    let available_channels = if matches!(status, rpc::Status::Connected { .. }) {
-                        let response = rpc
-                            .request(proto::GetChannels {})
-                            .await
-                            .context("failed to fetch available channels")?;
-                        Some(response.channels.into_iter().map(Into::into).collect())
-                    } else {
-                        None
-                    };
-
-                    this.update(&mut cx, |this, cx| {
-                        if available_channels.is_none() {
-                            if this.available_channels.is_none() {
-                                return;
-                            }
-                            this.channels.clear();
+                    match status {
+                        rpc::Status::Connected { .. } => {
+                            let response = rpc
+                                .request(proto::GetChannels {})
+                                .await
+                                .context("failed to fetch available channels")?;
+                            this.update(&mut cx, |this, cx| {
+                                this.available_channels =
+                                    Some(response.channels.into_iter().map(Into::into).collect());
+
+                                let mut to_remove = Vec::new();
+                                for (channel_id, channel) in &this.channels {
+                                    if let Some(channel) = channel.upgrade(cx) {
+                                        channel.update(cx, |channel, cx| channel.rejoin(cx))
+                                    } else {
+                                        to_remove.push(*channel_id);
+                                    }
+                                }
+
+                                for channel_id in to_remove {
+                                    this.channels.remove(&channel_id);
+                                }
+                                cx.notify();
+                            });
                         }
-                        this.available_channels = available_channels;
-                        cx.notify();
-                    });
+                        rpc::Status::Disconnected { .. } => {
+                            this.update(&mut cx, |this, cx| {
+                                this.available_channels = None;
+                                this.channels.clear();
+                                cx.notify();
+                            });
+                        }
+                        _ => {}
+                    }
                 }
             }
             .log_err()
@@ -282,6 +297,43 @@ impl Channel {
         false
     }
 
+    pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
+        let user_store = self.user_store.clone();
+        let rpc = self.rpc.clone();
+        let channel_id = self.details.id;
+        cx.spawn(|channel, mut cx| {
+            async move {
+                let response = rpc.request(proto::JoinChannel { channel_id }).await?;
+                let messages = messages_from_proto(response.messages, &user_store).await?;
+                let loaded_all_messages = response.done;
+
+                channel.update(&mut cx, |channel, cx| {
+                    if let Some((first_new_message, last_old_message)) =
+                        messages.first().zip(channel.messages.last())
+                    {
+                        if first_new_message.id > last_old_message.id {
+                            let old_messages = mem::take(&mut channel.messages);
+                            cx.emit(ChannelEvent::MessagesUpdated {
+                                old_range: 0..old_messages.summary().count,
+                                new_count: 0,
+                            });
+                            channel.loaded_all_messages = loaded_all_messages;
+                        }
+                    }
+
+                    channel.insert_messages(messages, cx);
+                    if loaded_all_messages {
+                        channel.loaded_all_messages = loaded_all_messages;
+                    }
+                });
+
+                Ok(())
+            }
+            .log_err()
+        })
+        .detach();
+    }
+
     pub fn message_count(&self) -> usize {
         self.messages.summary().count
     }
@@ -347,7 +399,7 @@ impl Channel {
             drop(old_cursor);
             self.messages = new_messages;
 
-            cx.emit(ChannelEvent::MessagesAdded {
+            cx.emit(ChannelEvent::MessagesUpdated {
                 old_range: start_ix..end_ix,
                 new_count,
             });
@@ -539,7 +591,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range: 0..0,
                 new_count: 2,
             }
@@ -588,7 +640,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range: 2..2,
                 new_count: 1,
             }
@@ -635,7 +687,7 @@ mod tests {
 
         assert_eq!(
             channel.next_event(&cx).await,
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range: 0..0,
                 new_count: 2,
             }

zed/src/chat_panel.rs 🔗

@@ -194,7 +194,7 @@ impl ChatPanel {
         cx: &mut ViewContext<Self>,
     ) {
         match event {
-            ChannelEvent::MessagesAdded {
+            ChannelEvent::MessagesUpdated {
                 old_range,
                 new_count,
             } => {

zed/src/rpc.rs 🔗

@@ -140,7 +140,7 @@ impl Client {
                 state._maintain_connection = Some(cx.foreground().spawn(async move {
                     let mut next_ping_id = 0;
                     loop {
-                        foreground.sleep(heartbeat_interval).await;
+                        foreground.timer(heartbeat_interval).await;
                         this.request(proto::Ping { id: next_ping_id })
                             .await
                             .unwrap();
@@ -162,7 +162,7 @@ impl Client {
                             },
                             &cx,
                         );
-                        foreground.sleep(delay).await;
+                        foreground.timer(delay).await;
                         delay_seconds = (delay_seconds * 2).min(300);
                     }
                 }));
@@ -233,11 +233,14 @@ impl Client {
     ) -> anyhow::Result<()> {
         let was_disconnected = match *self.status().borrow() {
             Status::Disconnected => true,
+            Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
+                false
+            }
             Status::Connected { .. }
             | Status::Connecting { .. }
             | Status::Reconnecting { .. }
+            | Status::Authenticating
             | Status::Reauthenticating => return Ok(()),
-            _ => false,
         };
 
         if was_disconnected {

zrpc/src/conn.rs 🔗

@@ -72,12 +72,12 @@ impl Conn {
         let rx = stream::select(
             rx.map(Ok),
             kill_rx.filter_map(|kill| {
-                if let Some(_) = kill {
+                if kill.is_none() {
+                    future::ready(None)
+                } else {
                     future::ready(Some(Err(
                         Error::new(ErrorKind::Other, "connection killed").into()
                     )))
-                } else {
-                    future::ready(None)
                 }
             }),
         );