Detailed changes
@@ -3,8 +3,9 @@ use async_task::Runnable;
pub use async_task::Task;
use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
use parking_lot::Mutex;
+use postage::{barrier, prelude::Stream as _};
use rand::prelude::*;
-use smol::{channel, prelude::*, Executor};
+use smol::{channel, prelude::*, Executor, Timer};
use std::{
fmt::{self, Debug},
marker::PhantomData,
@@ -18,7 +19,7 @@ use std::{
},
task::{Context, Poll},
thread,
- time::Duration,
+ time::{Duration, Instant},
};
use waker_fn::waker_fn;
@@ -49,6 +50,8 @@ struct DeterministicState {
spawned_from_foreground: Vec<(Runnable, Backtrace)>,
forbid_parking: bool,
block_on_ticks: RangeInclusive<usize>,
+ now: Instant,
+ pending_sleeps: Vec<(Instant, barrier::Sender)>,
}
pub struct Deterministic {
@@ -67,6 +70,8 @@ impl Deterministic {
spawned_from_foreground: Default::default(),
forbid_parking: false,
block_on_ticks: 0..=1000,
+ now: Instant::now(),
+ pending_sleeps: Default::default(),
})),
parker: Default::default(),
}
@@ -407,6 +412,35 @@ impl Foreground {
}
}
+ pub async fn sleep(&self, duration: Duration) {
+ match self {
+ Self::Deterministic(executor) => {
+ let (tx, mut rx) = barrier::channel();
+ {
+ let mut state = executor.state.lock();
+ let wakeup_at = state.now + duration;
+ state.pending_sleeps.push((wakeup_at, tx));
+ }
+ rx.recv().await;
+ }
+ _ => {
+ Timer::after(duration).await;
+ }
+ }
+ }
+
+ pub fn advance_clock(&self, duration: Duration) {
+ match self {
+ Self::Deterministic(executor) => {
+ let mut state = executor.state.lock();
+ state.now += duration;
+ let now = state.now;
+ state.pending_sleeps.retain(|(wakeup, _)| *wakeup > now);
+ }
+ _ => panic!("this method can only be called on a deterministic executor"),
+ }
+ }
+
pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
match self {
Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
@@ -1469,7 +1469,7 @@ mod tests {
.await;
// Drop client B's connection and ensure client A observes client B leaving the worktree.
- client_b.disconnect().await.unwrap();
+ client_b.disconnect(&cx_b.to_async()).await.unwrap();
worktree_a
.condition(&cx_a, |tree, _| tree.peers().len() == 0)
.await;
@@ -443,9 +443,8 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::test::FakeServer;
use gpui::TestAppContext;
- use postage::mpsc::Receiver;
- use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
#[gpui::test]
async fn test_channel_messages(mut cx: TestAppContext) {
@@ -458,7 +457,7 @@ mod tests {
channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
// Get the available channels.
- let get_channels = server.receive::<proto::GetChannels>().await;
+ let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
server
.respond(
get_channels.receipt(),
@@ -489,7 +488,7 @@ mod tests {
})
.unwrap();
channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
- let join_channel = server.receive::<proto::JoinChannel>().await;
+ let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
server
.respond(
join_channel.receipt(),
@@ -514,7 +513,7 @@ mod tests {
.await;
// Client requests all users for the received messages
- let mut get_users = server.receive::<proto::GetUsers>().await;
+ let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
get_users.payload.user_ids.sort();
assert_eq!(get_users.payload.user_ids, vec![5, 6]);
server
@@ -571,7 +570,7 @@ mod tests {
.await;
// Client requests user for message since they haven't seen them yet
- let get_users = server.receive::<proto::GetUsers>().await;
+ let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![7]);
server
.respond(
@@ -607,7 +606,7 @@ mod tests {
channel.update(&mut cx, |channel, cx| {
assert!(channel.load_more_messages(cx));
});
- let get_messages = server.receive::<proto::GetChannelMessages>().await;
+ let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5);
assert_eq!(get_messages.payload.before_message_id, 10);
server
@@ -653,53 +652,4 @@ mod tests {
);
});
}
-
- struct FakeServer {
- peer: Arc<Peer>,
- incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
- connection_id: ConnectionId,
- }
-
- impl FakeServer {
- async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
- let (client_conn, server_conn) = Channel::bidirectional();
- let peer = Peer::new();
- let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
- cx.background().spawn(io).detach();
-
- client
- .set_connection(user_id, client_conn, &cx.to_async())
- .await
- .unwrap();
-
- Self {
- peer,
- incoming,
- connection_id,
- }
- }
-
- async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
- self.peer.send(self.connection_id, message).await.unwrap();
- }
-
- async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
- *self
- .incoming
- .recv()
- .await
- .unwrap()
- .into_any()
- .downcast::<TypedEnvelope<M>>()
- .unwrap()
- }
-
- async fn respond<T: proto::RequestMessage>(
- &self,
- receipt: Receipt<T>,
- response: T::Response,
- ) {
- self.peer.respond(receipt, response).await.unwrap()
- }
- }
}
@@ -3,10 +3,12 @@ use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::{
http::Request, Error as WebSocketError, Message as WebSocketMessage,
};
+use futures::StreamExt as _;
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
use postage::{prelude::Stream, watch};
+use smol::Timer;
use std::{
any::TypeId,
collections::HashMap,
@@ -42,6 +44,10 @@ pub enum Status {
user_id: u64,
},
ConnectionLost,
+ Reconnecting,
+ ReconnectionError {
+ next_reconnection: Instant,
+ },
}
struct ClientState {
@@ -51,6 +57,8 @@ struct ClientState {
(TypeId, u64),
Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
>,
+ _maintain_connection: Option<Task<()>>,
+ heartbeat_interval: Duration,
}
impl Default for ClientState {
@@ -59,6 +67,8 @@ impl Default for ClientState {
status: watch::channel_with(Status::Disconnected),
entity_id_extractors: Default::default(),
model_handlers: Default::default(),
+ _maintain_connection: None,
+ heartbeat_interval: Duration::from_secs(5),
}
}
}
@@ -95,9 +105,35 @@ impl Client {
self.state.read().status.1.clone()
}
- fn set_status(&self, status: Status) {
+ fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
let mut state = self.state.write();
*state.status.0.borrow_mut() = status;
+ match status {
+ Status::Connected { .. } => {
+ let heartbeat_interval = state.heartbeat_interval;
+ let this = self.clone();
+ let foreground = cx.foreground();
+ state._maintain_connection = Some(cx.foreground().spawn(async move {
+ let mut next_ping_id = 0;
+ loop {
+ foreground.sleep(heartbeat_interval).await;
+ this.request(proto::Ping { id: next_ping_id })
+ .await
+ .unwrap();
+ next_ping_id += 1;
+ }
+ }));
+ }
+ Status::ConnectionLost => {
+ state._maintain_connection = Some(cx.foreground().spawn(async move {
+ // TODO: try to reconnect
+ }));
+ }
+ Status::Disconnected => {
+ state._maintain_connection.take();
+ }
+ _ => {}
+ }
}
pub fn subscribe_from_model<T, M, F>(
@@ -167,14 +203,14 @@ impl Client {
let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
let user_id = user_id.parse::<u64>()?;
- self.set_status(Status::Connecting);
+ self.set_status(Status::Connecting, cx);
match self.connect(user_id, &access_token, cx).await {
Ok(()) => {
log::info!("connected to rpc address {}", *ZED_SERVER_URL);
Ok(())
}
Err(err) => {
- self.set_status(Status::ConnectionError);
+ self.set_status(Status::ConnectionError, cx);
Err(err)
}
}
@@ -256,20 +292,24 @@ impl Client {
.detach();
}
- self.set_status(Status::Connected {
- connection_id,
- user_id,
- });
+ self.set_status(
+ Status::Connected {
+ connection_id,
+ user_id,
+ },
+ cx,
+ );
let handle_io = cx.background().spawn(handle_io);
let this = self.clone();
+ let cx = cx.clone();
cx.foreground()
.spawn(async move {
match handle_io.await {
- Ok(()) => this.set_status(Status::Disconnected),
+ Ok(()) => this.set_status(Status::Disconnected, &cx),
Err(err) => {
log::error!("connection error: {:?}", err);
- this.set_status(Status::ConnectionLost);
+ this.set_status(Status::ConnectionLost, &cx);
}
}
})
@@ -359,10 +399,10 @@ impl Client {
})
}
- pub async fn disconnect(&self) -> Result<()> {
+ pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id).await;
- self.set_status(Status::Disconnected);
+ self.set_status(Status::Disconnected, cx);
Ok(())
}
@@ -444,13 +484,40 @@ const LOGIN_RESPONSE: &'static str = "
</html>
";
-#[test]
-fn test_encode_and_decode_worktree_url() {
- let url = encode_worktree_url(5, "deadbeef");
- assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
- assert_eq!(
- decode_worktree_url(&format!("\n {}\t", url)),
- Some((5, "deadbeef".to_string()))
- );
- assert_eq!(decode_worktree_url("not://the-right-format"), None);
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::FakeServer;
+ use gpui::TestAppContext;
+
+ #[gpui::test(iterations = 1000)]
+ async fn test_heartbeat(cx: TestAppContext) {
+ let user_id = 5;
+ let client = Client::new();
+
+ client.state.write().heartbeat_interval = Duration::from_millis(1);
+ let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+
+ let ping = server.receive::<proto::Ping>().await.unwrap();
+ assert_eq!(ping.payload.id, 0);
+ server.respond(ping.receipt(), proto::Pong { id: 0 }).await;
+
+ let ping = server.receive::<proto::Ping>().await.unwrap();
+ assert_eq!(ping.payload.id, 1);
+ server.respond(ping.receipt(), proto::Pong { id: 1 }).await;
+
+ client.disconnect(&cx.to_async()).await.unwrap();
+ assert!(server.receive::<proto::Ping>().await.is_err());
+ }
+
+ #[test]
+ fn test_encode_and_decode_worktree_url() {
+ let url = encode_worktree_url(5, "deadbeef");
+ assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
+ assert_eq!(
+ decode_worktree_url(&format!("\n {}\t", url)),
+ Some((5, "deadbeef".to_string()))
+ );
+ assert_eq!(decode_worktree_url("not://the-right-format"), None);
+ }
}
@@ -3,14 +3,16 @@ use crate::{
channel::ChannelList,
fs::RealFs,
language::LanguageRegistry,
- rpc,
+ rpc::{self, Client},
settings::{self, ThemeRegistry},
time::ReplicaId,
user::UserStore,
AppState,
};
-use gpui::{Entity, ModelHandle, MutableAppContext};
+use anyhow::{anyhow, Result};
+use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext};
use parking_lot::Mutex;
+use postage::{mpsc, prelude::Stream as _};
use smol::channel;
use std::{
marker::PhantomData,
@@ -18,6 +20,7 @@ use std::{
sync::Arc,
};
use tempdir::TempDir;
+use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(feature = "test-support")]
pub use zrpc::test::Channel;
@@ -195,3 +198,50 @@ impl<T: Entity> Observer<T> {
(observer, notify_rx)
}
}
+
+pub struct FakeServer {
+ peer: Arc<Peer>,
+ incoming: mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>,
+ connection_id: ConnectionId,
+}
+
+impl FakeServer {
+ pub async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
+ let (client_conn, server_conn) = zrpc::test::Channel::bidirectional();
+ let peer = Peer::new();
+ let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+ cx.background().spawn(io).detach();
+
+ client
+ .set_connection(user_id, client_conn, &cx.to_async())
+ .await
+ .unwrap();
+
+ Self {
+ peer,
+ incoming,
+ connection_id,
+ }
+ }
+
+ pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
+ self.peer.send(self.connection_id, message).await.unwrap();
+ }
+
+ pub async fn receive<M: proto::EnvelopedMessage>(&mut self) -> Result<TypedEnvelope<M>> {
+ let message = self
+ .incoming
+ .recv()
+ .await
+ .ok_or_else(|| anyhow!("other half hung up"))?;
+ Ok(*message.into_any().downcast::<TypedEnvelope<M>>().unwrap())
+ }
+
+ pub async fn respond<T: proto::RequestMessage>(
+ &self,
+ receipt: Receipt<T>,
+ response: T::Response,
+ ) {
+ self.peer.respond(receipt, response).await.unwrap()
+ }
+}