Start on a `Client::status` method that can be observed

Antonio Scandurra created

Change summary

server/src/rpc.rs     |   2 
zed/src/channel.rs    |  21 +++-----
zed/src/chat_panel.rs |  30 +++++++-----
zed/src/rpc.rs        | 105 +++++++++++++++++++++++++++++++++-----------
zrpc/src/peer.rs      |   5 ++
5 files changed, 111 insertions(+), 52 deletions(-)

Detailed changes

server/src/rpc.rs 🔗

@@ -1714,7 +1714,7 @@ mod tests {
                 )
                 .detach();
             client
-                .add_connection(user_id.to_proto(), client_conn, &cx.to_async())
+                .set_connection(user_id.to_proto(), client_conn, &cx.to_async())
                 .await
                 .unwrap();
             (user_id, client)

zed/src/channel.rs 🔗

@@ -90,18 +90,15 @@ impl ChannelList {
         let _task = cx.spawn(|this, mut cx| {
             let rpc = rpc.clone();
             async move {
-                let mut user_id = rpc.user_id();
+                let mut status = rpc.status();
                 loop {
-                    let available_channels = if user_id.recv().await.unwrap().is_some() {
-                        Some(
-                            rpc.request(proto::GetChannels {})
-                                .await
-                                .context("failed to fetch available channels")?
-                                .channels
-                                .into_iter()
-                                .map(Into::into)
-                                .collect(),
-                        )
+                    let status = status.recv().await.unwrap();
+                    let available_channels = if matches!(status, rpc::Status::Connected { .. }) {
+                        let response = rpc
+                            .request(proto::GetChannels {})
+                            .await
+                            .context("failed to fetch available channels")?;
+                        Some(response.channels.into_iter().map(Into::into).collect())
                     } else {
                         None
                     };
@@ -671,7 +668,7 @@ mod tests {
             cx.background().spawn(io).detach();
 
             client
-                .add_connection(user_id, client_conn, &cx.to_async())
+                .set_connection(user_id, client_conn, &cx.to_async())
                 .await
                 .unwrap();
 

zed/src/chat_panel.rs 🔗

@@ -3,7 +3,7 @@ use std::sync::Arc;
 use crate::{
     channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
     editor::Editor,
-    rpc::Client,
+    rpc::{self, Client},
     theme,
     util::{ResultExt, TryFutureExt},
     Settings,
@@ -14,10 +14,10 @@ use gpui::{
     keymap::Binding,
     platform::CursorStyle,
     views::{ItemType, Select, SelectStyle},
-    AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
+    AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
     ViewContext, ViewHandle,
 };
-use postage::watch;
+use postage::{prelude::Stream, watch};
 use time::{OffsetDateTime, UtcOffset};
 
 const MESSAGE_LOADING_THRESHOLD: usize = 50;
@@ -31,6 +31,7 @@ pub struct ChatPanel {
     channel_select: ViewHandle<Select>,
     settings: watch::Receiver<Settings>,
     local_timezone: UtcOffset,
+    _status_observer: Task<()>,
 }
 
 pub enum Event {}
@@ -98,6 +99,14 @@ impl ChatPanel {
                 cx.dispatch_action(LoadMoreMessages);
             }
         });
+        let _status_observer = cx.spawn(|this, mut cx| {
+            let mut status = rpc.status();
+            async move {
+                while let Some(_) = status.recv().await {
+                    this.update(&mut cx, |_, cx| cx.notify());
+                }
+            }
+        });
 
         let mut this = Self {
             rpc,
@@ -108,6 +117,7 @@ impl ChatPanel {
             channel_select,
             settings,
             local_timezone: cx.platform().local_timezone(),
+            _status_observer,
         };
 
         this.init_active_channel(cx);
@@ -153,6 +163,7 @@ impl ChatPanel {
         if let Some(active_channel) = active_channel {
             self.set_active_channel(active_channel, cx);
         } else {
+            self.message_list.reset(0);
             self.active_channel = None;
         }
 
@@ -357,10 +368,6 @@ impl ChatPanel {
             })
         }
     }
-
-    fn is_signed_in(&self) -> bool {
-        self.rpc.user_id().borrow().is_some()
-    }
 }
 
 impl Entity for ChatPanel {
@@ -374,10 +381,9 @@ impl View for ChatPanel {
 
     fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
         let theme = &self.settings.borrow().theme;
-        let element = if self.is_signed_in() {
-            self.render_channel()
-        } else {
-            self.render_sign_in_prompt(cx)
+        let element = match *self.rpc.status().borrow() {
+            rpc::Status::Connected { .. } => self.render_channel(),
+            _ => self.render_sign_in_prompt(cx),
         };
         ConstrainedBox::new(
             Container::new(element)
@@ -389,7 +395,7 @@ impl View for ChatPanel {
     }
 
     fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
-        if self.is_signed_in() {
+        if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
             cx.focus(&self.input_editor);
         }
     }

zed/src/rpc.rs 🔗

@@ -31,9 +31,20 @@ pub struct Client {
     state: RwLock<ClientState>,
 }
 
+#[derive(Copy, Clone, Debug)]
+pub enum Status {
+    Disconnected,
+    Connecting,
+    ConnectionError,
+    Connected {
+        connection_id: ConnectionId,
+        user_id: u64,
+    },
+    ConnectionLost,
+}
+
 struct ClientState {
-    connection_id: Option<ConnectionId>,
-    user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
+    status: (watch::Sender<Status>, watch::Receiver<Status>),
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
     model_handlers: HashMap<
         (TypeId, u64),
@@ -44,8 +55,7 @@ struct ClientState {
 impl Default for ClientState {
     fn default() -> Self {
         Self {
-            connection_id: Default::default(),
-            user_id: watch::channel(),
+            status: watch::channel_with(Status::Disconnected),
             entity_id_extractors: Default::default(),
             model_handlers: Default::default(),
         }
@@ -80,8 +90,14 @@ impl Client {
         })
     }
 
-    pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
-        self.state.read().user_id.1.clone()
+    pub fn status(&self) -> watch::Receiver<Status> {
+        self.state.read().status.1.clone()
+    }
+
+    async fn set_status(&self, status: Status) -> Result<()> {
+        let mut state = self.state.write();
+        state.status.0.send(status).await?;
+        Ok(())
     }
 
     pub fn subscribe_from_model<T, M, F>(
@@ -141,43 +157,64 @@ impl Client {
         self: &Arc<Self>,
         cx: &AsyncAppContext,
     ) -> anyhow::Result<()> {
-        if self.state.read().connection_id.is_some() {
+        if matches!(
+            *self.status().borrow(),
+            Status::Connecting | Status::Connected { .. }
+        ) {
             return Ok(());
         }
 
         let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
         let user_id = user_id.parse::<u64>()?;
+
+        self.set_status(Status::Connecting).await?;
+        match self.connect(user_id, &access_token, cx).await {
+            Ok(()) => {
+                log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+                Ok(())
+            }
+            Err(err) => {
+                self.set_status(Status::ConnectionError).await?;
+                Err(err)
+            }
+        }
+    }
+
+    async fn connect(
+        self: &Arc<Self>,
+        user_id: u64,
+        access_token: &str,
+        cx: &AsyncAppContext,
+    ) -> Result<()> {
         let request =
             Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
-
         if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
             let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
                 .await
                 .context("websocket handshake")?;
-            self.add_connection(user_id, stream, cx).await?;
+            self.set_connection(user_id, stream, cx).await?;
+            Ok(())
         } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
             let stream = smol::net::TcpStream::connect(host).await?;
             let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
             let (stream, _) = async_tungstenite::client_async(request, stream)
                 .await
                 .context("websocket handshake")?;
-            self.add_connection(user_id, stream, cx).await?;
+            self.set_connection(user_id, stream, cx).await?;
+            Ok(())
         } else {
-            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
-        };
-
-        log::info!("connected to rpc address {}", *ZED_SERVER_URL);
-        Ok(())
+            return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL));
+        }
     }
 
-    pub async fn add_connection<Conn>(
+    pub async fn set_connection<Conn>(
         self: &Arc<Self>,
         user_id: u64,
         conn: Conn,
         cx: &AsyncAppContext,
-    ) -> anyhow::Result<()>
+    ) -> Result<()>
     where
         Conn: 'static
             + futures::Sink<WebSocketMessage, Error = WebSocketError>
@@ -218,16 +255,28 @@ impl Client {
                 })
                 .detach();
         }
-        cx.background()
+
+        self.set_status(Status::Connected {
+            connection_id,
+            user_id,
+        })
+        .await?;
+
+        let handle_io = cx.background().spawn(handle_io);
+        let this = self.clone();
+        cx.foreground()
             .spawn(async move {
-                if let Err(error) = handle_io.await {
-                    log::error!("connection error: {:?}", error);
+                match handle_io.await {
+                    Ok(()) => {
+                        let _ = this.set_status(Status::Disconnected).await;
+                    }
+                    Err(err) => {
+                        log::error!("connection error: {:?}", err);
+                        let _ = this.set_status(Status::ConnectionLost).await;
+                    }
                 }
             })
             .detach();
-        let mut state = self.state.write();
-        state.connection_id = Some(connection_id);
-        state.user_id.0.send(Some(user_id)).await?;
         Ok(())
     }
 
@@ -316,14 +365,16 @@ impl Client {
     pub async fn disconnect(&self) -> Result<()> {
         let conn_id = self.connection_id()?;
         self.peer.disconnect(conn_id).await;
+        self.set_status(Status::Disconnected).await?;
         Ok(())
     }
 
     fn connection_id(&self) -> Result<ConnectionId> {
-        self.state
-            .read()
-            .connection_id
-            .ok_or_else(|| anyhow!("not connected"))
+        if let Status::Connected { connection_id, .. } = *self.status().borrow() {
+            Ok(connection_id)
+        } else {
+            Err(anyhow!("not connected"))
+        }
     }
 
     pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {

zrpc/src/peer.rs 🔗

@@ -127,6 +127,7 @@ impl Peer {
         let mut writer = MessageStream::new(tx);
         let mut reader = MessageStream::new(rx);
 
+        let this = self.clone();
         let response_channels = connection.response_channels.clone();
         let handle_io = async move {
             loop {
@@ -147,6 +148,7 @@ impl Peer {
                                     if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
                                         if incoming_tx.send(envelope).await.is_err() {
                                             response_channels.lock().await.clear();
+                                            this.connections.write().await.remove(&connection_id);
                                             return Ok(())
                                         }
                                     } else {
@@ -158,6 +160,7 @@ impl Peer {
                             }
                             Err(error) => {
                                 response_channels.lock().await.clear();
+                                this.connections.write().await.remove(&connection_id);
                                 Err(error).context("received invalid RPC message")?;
                             }
                         },
@@ -165,11 +168,13 @@ impl Peer {
                             Some(outgoing) => {
                                 if let Err(result) = writer.write_message(&outgoing).await {
                                     response_channels.lock().await.clear();
+                                    this.connections.write().await.remove(&connection_id);
                                     Err(result).context("failed to write RPC message")?;
                                 }
                             }
                             None => {
                                 response_channels.lock().await.clear();
+                                this.connections.write().await.remove(&connection_id);
                                 return Ok(())
                             }
                         }