Detailed changes
@@ -1714,7 +1714,7 @@ mod tests {
)
.detach();
client
- .add_connection(user_id.to_proto(), client_conn, &cx.to_async())
+ .set_connection(user_id.to_proto(), client_conn, &cx.to_async())
.await
.unwrap();
(user_id, client)
@@ -90,18 +90,15 @@ impl ChannelList {
let _task = cx.spawn(|this, mut cx| {
let rpc = rpc.clone();
async move {
- let mut user_id = rpc.user_id();
+ let mut status = rpc.status();
loop {
- let available_channels = if user_id.recv().await.unwrap().is_some() {
- Some(
- rpc.request(proto::GetChannels {})
- .await
- .context("failed to fetch available channels")?
- .channels
- .into_iter()
- .map(Into::into)
- .collect(),
- )
+ let status = status.recv().await.unwrap();
+ let available_channels = if matches!(status, rpc::Status::Connected { .. }) {
+ let response = rpc
+ .request(proto::GetChannels {})
+ .await
+ .context("failed to fetch available channels")?;
+ Some(response.channels.into_iter().map(Into::into).collect())
} else {
None
};
@@ -671,7 +668,7 @@ mod tests {
cx.background().spawn(io).detach();
client
- .add_connection(user_id, client_conn, &cx.to_async())
+ .set_connection(user_id, client_conn, &cx.to_async())
.await
.unwrap();
@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{
channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
editor::Editor,
- rpc::Client,
+ rpc::{self, Client},
theme,
util::{ResultExt, TryFutureExt},
Settings,
@@ -14,10 +14,10 @@ use gpui::{
keymap::Binding,
platform::CursorStyle,
views::{ItemType, Select, SelectStyle},
- AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
+ AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
ViewContext, ViewHandle,
};
-use postage::watch;
+use postage::{prelude::Stream, watch};
use time::{OffsetDateTime, UtcOffset};
const MESSAGE_LOADING_THRESHOLD: usize = 50;
@@ -31,6 +31,7 @@ pub struct ChatPanel {
channel_select: ViewHandle<Select>,
settings: watch::Receiver<Settings>,
local_timezone: UtcOffset,
+ _status_observer: Task<()>,
}
pub enum Event {}
@@ -98,6 +99,14 @@ impl ChatPanel {
cx.dispatch_action(LoadMoreMessages);
}
});
+ let _status_observer = cx.spawn(|this, mut cx| {
+ let mut status = rpc.status();
+ async move {
+ while let Some(_) = status.recv().await {
+ this.update(&mut cx, |_, cx| cx.notify());
+ }
+ }
+ });
let mut this = Self {
rpc,
@@ -108,6 +117,7 @@ impl ChatPanel {
channel_select,
settings,
local_timezone: cx.platform().local_timezone(),
+ _status_observer,
};
this.init_active_channel(cx);
@@ -153,6 +163,7 @@ impl ChatPanel {
if let Some(active_channel) = active_channel {
self.set_active_channel(active_channel, cx);
} else {
+ self.message_list.reset(0);
self.active_channel = None;
}
@@ -357,10 +368,6 @@ impl ChatPanel {
})
}
}
-
- fn is_signed_in(&self) -> bool {
- self.rpc.user_id().borrow().is_some()
- }
}
impl Entity for ChatPanel {
@@ -374,10 +381,9 @@ impl View for ChatPanel {
fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
let theme = &self.settings.borrow().theme;
- let element = if self.is_signed_in() {
- self.render_channel()
- } else {
- self.render_sign_in_prompt(cx)
+ let element = match *self.rpc.status().borrow() {
+ rpc::Status::Connected { .. } => self.render_channel(),
+ _ => self.render_sign_in_prompt(cx),
};
ConstrainedBox::new(
Container::new(element)
@@ -389,7 +395,7 @@ impl View for ChatPanel {
}
fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
- if self.is_signed_in() {
+ if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
cx.focus(&self.input_editor);
}
}
@@ -31,9 +31,20 @@ pub struct Client {
state: RwLock<ClientState>,
}
+#[derive(Copy, Clone, Debug)]
+pub enum Status {
+ Disconnected,
+ Connecting,
+ ConnectionError,
+ Connected {
+ connection_id: ConnectionId,
+ user_id: u64,
+ },
+ ConnectionLost,
+}
+
struct ClientState {
- connection_id: Option<ConnectionId>,
- user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
+ status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
(TypeId, u64),
@@ -44,8 +55,7 @@ struct ClientState {
impl Default for ClientState {
fn default() -> Self {
Self {
- connection_id: Default::default(),
- user_id: watch::channel(),
+ status: watch::channel_with(Status::Disconnected),
entity_id_extractors: Default::default(),
model_handlers: Default::default(),
}
@@ -80,8 +90,14 @@ impl Client {
})
}
- pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
- self.state.read().user_id.1.clone()
+ pub fn status(&self) -> watch::Receiver<Status> {
+ self.state.read().status.1.clone()
+ }
+
+ async fn set_status(&self, status: Status) -> Result<()> {
+ let mut state = self.state.write();
+ state.status.0.send(status).await?;
+ Ok(())
}
pub fn subscribe_from_model<T, M, F>(
@@ -141,43 +157,64 @@ impl Client {
self: &Arc<Self>,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
- if self.state.read().connection_id.is_some() {
+ if matches!(
+ *self.status().borrow(),
+ Status::Connecting | Status::Connected { .. }
+ ) {
return Ok(());
}
let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
let user_id = user_id.parse::<u64>()?;
+
+ self.set_status(Status::Connecting).await?;
+ match self.connect(user_id, &access_token, cx).await {
+ Ok(()) => {
+ log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+ Ok(())
+ }
+ Err(err) => {
+ self.set_status(Status::ConnectionError).await?;
+ Err(err)
+ }
+ }
+ }
+
+ async fn connect(
+ self: &Arc<Self>,
+ user_id: u64,
+ access_token: &str,
+ cx: &AsyncAppContext,
+ ) -> Result<()> {
let request =
Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
-
if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
.await
.context("websocket handshake")?;
- self.add_connection(user_id, stream, cx).await?;
+ self.set_connection(user_id, stream, cx).await?;
+ Ok(())
} else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
let stream = smol::net::TcpStream::connect(host).await?;
let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
let (stream, _) = async_tungstenite::client_async(request, stream)
.await
.context("websocket handshake")?;
- self.add_connection(user_id, stream, cx).await?;
+ self.set_connection(user_id, stream, cx).await?;
+ Ok(())
} else {
- return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
- };
-
- log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- Ok(())
+ return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL));
+ }
}
- pub async fn add_connection<Conn>(
+ pub async fn set_connection<Conn>(
self: &Arc<Self>,
user_id: u64,
conn: Conn,
cx: &AsyncAppContext,
- ) -> anyhow::Result<()>
+ ) -> Result<()>
where
Conn: 'static
+ futures::Sink<WebSocketMessage, Error = WebSocketError>
@@ -218,16 +255,28 @@ impl Client {
})
.detach();
}
- cx.background()
+
+ self.set_status(Status::Connected {
+ connection_id,
+ user_id,
+ })
+ .await?;
+
+ let handle_io = cx.background().spawn(handle_io);
+ let this = self.clone();
+ cx.foreground()
.spawn(async move {
- if let Err(error) = handle_io.await {
- log::error!("connection error: {:?}", error);
+ match handle_io.await {
+ Ok(()) => {
+ let _ = this.set_status(Status::Disconnected).await;
+ }
+ Err(err) => {
+ log::error!("connection error: {:?}", err);
+ let _ = this.set_status(Status::ConnectionLost).await;
+ }
}
})
.detach();
- let mut state = self.state.write();
- state.connection_id = Some(connection_id);
- state.user_id.0.send(Some(user_id)).await?;
Ok(())
}
@@ -316,14 +365,16 @@ impl Client {
pub async fn disconnect(&self) -> Result<()> {
let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id).await;
+ self.set_status(Status::Disconnected).await?;
Ok(())
}
fn connection_id(&self) -> Result<ConnectionId> {
- self.state
- .read()
- .connection_id
- .ok_or_else(|| anyhow!("not connected"))
+ if let Status::Connected { connection_id, .. } = *self.status().borrow() {
+ Ok(connection_id)
+ } else {
+ Err(anyhow!("not connected"))
+ }
}
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
@@ -127,6 +127,7 @@ impl Peer {
let mut writer = MessageStream::new(tx);
let mut reader = MessageStream::new(rx);
+ let this = self.clone();
let response_channels = connection.response_channels.clone();
let handle_io = async move {
loop {
@@ -147,6 +148,7 @@ impl Peer {
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
if incoming_tx.send(envelope).await.is_err() {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
return Ok(())
}
} else {
@@ -158,6 +160,7 @@ impl Peer {
}
Err(error) => {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
Err(error).context("received invalid RPC message")?;
}
},
@@ -165,11 +168,13 @@ impl Peer {
Some(outgoing) => {
if let Err(result) = writer.write_message(&outgoing).await {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
Err(result).context("failed to write RPC message")?;
}
}
None => {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
return Ok(())
}
}