Detailed changes
@@ -998,7 +998,6 @@ dependencies = [
name = "clock"
version = "0.1.0"
dependencies = [
- "rpc",
"smallvec",
]
@@ -2236,6 +2235,7 @@ dependencies = [
"tiny-skia",
"tree-sitter",
"usvg",
+ "util",
"waker-fn",
]
@@ -3959,6 +3959,7 @@ dependencies = [
"async-lock",
"async-tungstenite",
"base64 0.13.0",
+ "clock",
"futures",
"gpui",
"log",
@@ -3972,6 +3973,7 @@ dependencies = [
"smol",
"smol-timeout",
"tempdir",
+ "util",
"zstd",
]
@@ -5574,7 +5576,6 @@ name = "util"
version = "0.1.0"
dependencies = [
"anyhow",
- "clock",
"futures",
"log",
"rand 0.8.3",
@@ -5959,6 +5960,7 @@ name = "zed-server"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-io",
"async-sqlx-session",
"async-std",
"async-trait",
@@ -137,8 +137,8 @@ struct ClientState {
credentials: Option<Credentials>,
status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
- _maintain_connection: Option<Task<()>>,
- heartbeat_interval: Duration,
+ _reconnect_task: Option<Task<()>>,
+ reconnect_interval: Duration,
models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
model_types_by_message_type: HashMap<TypeId, TypeId>,
@@ -168,8 +168,8 @@ impl Default for ClientState {
credentials: None,
status: watch::channel_with(Status::SignedOut),
entity_id_extractors: Default::default(),
- _maintain_connection: None,
- heartbeat_interval: Duration::from_secs(5),
+ _reconnect_task: None,
+ reconnect_interval: Duration::from_secs(5),
models_by_message_type: Default::default(),
models_by_entity_type_and_remote_id: Default::default(),
model_types_by_message_type: Default::default(),
@@ -236,7 +236,7 @@ impl Client {
#[cfg(any(test, feature = "test-support"))]
pub fn tear_down(&self) {
let mut state = self.state.write();
- state._maintain_connection.take();
+ state._reconnect_task.take();
state.message_handlers.clear();
state.models_by_message_type.clear();
state.models_by_entity_type_and_remote_id.clear();
@@ -283,21 +283,13 @@ impl Client {
match status {
Status::Connected { .. } => {
- let heartbeat_interval = state.heartbeat_interval;
- let this = self.clone();
- let foreground = cx.foreground();
- state._maintain_connection = Some(cx.foreground().spawn(async move {
- loop {
- foreground.timer(heartbeat_interval).await;
- let _ = this.request(proto::Ping {}).await;
- }
- }));
+ state._reconnect_task = None;
}
Status::ConnectionLost => {
let this = self.clone();
let foreground = cx.foreground();
- let heartbeat_interval = state.heartbeat_interval;
- state._maintain_connection = Some(cx.spawn(|cx| async move {
+ let reconnect_interval = state.reconnect_interval;
+ state._reconnect_task = Some(cx.spawn(|cx| async move {
let mut rng = StdRng::from_entropy();
let mut delay = Duration::from_millis(100);
while let Err(error) = this.authenticate_and_connect(&cx).await {
@@ -311,12 +303,12 @@ impl Client {
foreground.timer(delay).await;
delay = delay
.mul_f32(rng.gen_range(1.0..=2.0))
- .min(heartbeat_interval);
+ .min(reconnect_interval);
}
}));
}
Status::SignedOut | Status::UpgradeRequired => {
- state._maintain_connection.take();
+ state._reconnect_task.take();
}
_ => {}
}
@@ -548,7 +540,11 @@ impl Client {
}
async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
- let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
+ let executor = cx.background();
+ let (connection_id, handle_io, mut incoming) = self
+ .peer
+ .add_connection(conn, move |duration| executor.timer(duration))
+ .await;
cx.foreground()
.spawn({
let cx = cx.clone();
@@ -940,26 +936,6 @@ mod tests {
use crate::test::{FakeHttpClient, FakeServer};
use gpui::TestAppContext;
- #[gpui::test(iterations = 10)]
- async fn test_heartbeat(cx: &mut TestAppContext) {
- cx.foreground().forbid_parking();
-
- let user_id = 5;
- let mut client = Client::new(FakeHttpClient::with_404_response());
- let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
- cx.foreground().advance_clock(Duration::from_secs(10));
- let ping = server.receive::<proto::Ping>().await.unwrap();
- server.respond(ping.receipt(), proto::Ack {}).await;
-
- cx.foreground().advance_clock(Duration::from_secs(10));
- let ping = server.receive::<proto::Ping>().await.unwrap();
- server.respond(ping.receipt(), proto::Ack {}).await;
-
- client.disconnect(&cx.to_async()).unwrap();
- assert!(server.receive::<proto::Ping>().await.is_err());
- }
-
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: &mut TestAppContext) {
cx.foreground().forbid_parking();
@@ -75,7 +75,8 @@ impl FakeServer {
}
let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
- let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+ let (connection_id, io, incoming) =
+ peer.add_test_connection(server_conn, cx.background()).await;
cx.background().spawn(io).detach();
let mut state = state.lock();
state.connection_id = Some(connection_id);
@@ -14,6 +14,7 @@ test-support = ["backtrace", "dhat", "env_logger", "collections/test-support"]
[dependencies]
collections = { path = "../collections" }
gpui_macros = { path = "../gpui_macros" }
+util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
async-task = "4.0.3"
backtrace = { version = "0.3", optional = true }
@@ -26,7 +26,9 @@ rsa = "0.4"
serde = { version = "1", features = ["derive"] }
smol-timeout = "0.6"
zstd = "0.9"
+clock = { path = "../clock" }
gpui = { path = "../gpui", optional = true }
+util = { path = "../util" }
[build-dependencies]
prost-build = "0.8"
@@ -94,6 +94,7 @@ pub struct ConnectionState {
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
+const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(2);
const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
impl Peer {
@@ -104,14 +105,20 @@ impl Peer {
})
}
- pub async fn add_connection(
+ pub async fn add_connection<F, Fut, Out>(
self: &Arc<Self>,
connection: Connection,
+ create_timer: F,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
- ) {
+ )
+ where
+ F: Send + Fn(Duration) -> Fut,
+ Fut: Send + Future<Output = Out>,
+ Out: Send,
+ {
// For outgoing messages, use an unbounded channel so that application code
// can always send messages without yielding. For incoming messages, use a
// bounded channel so that other peers will receive backpressure if they send
@@ -121,7 +128,7 @@ impl Peer {
let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
let connection_state = ConnectionState {
- outgoing_tx,
+ outgoing_tx: outgoing_tx.clone(),
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
@@ -131,39 +138,43 @@ impl Peer {
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
- let result = 'outer: loop {
+ let _end_connection = util::defer(|| {
+ response_channels.lock().take();
+ this.connections.write().remove(&connection_id);
+ });
+
+ loop {
let read_message = reader.read_message().fuse();
futures::pin_mut!(read_message);
loop {
futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
- match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
- None => break 'outer Err(anyhow!("timed out writing RPC message")),
- Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
- _ => {}
+ if let Some(result) = writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
+ result.context("failed to write RPC message")?;
+ } else {
+ Err(anyhow!("timed out writing message"))?;
}
}
- None => break 'outer Ok(()),
+ None => return Ok(()),
},
- incoming = read_message => match incoming {
- Ok(incoming) => {
- if incoming_tx.send(incoming).await.is_err() {
- break 'outer Ok(());
- }
- break;
- }
- Err(error) => {
- break 'outer Err(error).context("received invalid RPC message")
+ incoming = read_message => {
+ let incoming = incoming.context("received invalid rpc message")?;
+ if incoming_tx.send(incoming).await.is_err() {
+ return Ok(());
}
+ break;
},
+ _ = create_timer(KEEPALIVE_INTERVAL).fuse() => {
+ if let Some(result) = writer.ping().timeout(WRITE_TIMEOUT).await {
+ result.context("failed to send websocket ping")?;
+ } else {
+ Err(anyhow!("timed out sending websocket ping"))?;
+ }
+ }
}
}
- };
-
- response_channels.lock().take();
- this.connections.write().remove(&connection_id);
- result
+ }
};
let response_channels = connection_state.response_channels.clone();
@@ -191,18 +202,31 @@ impl Peer {
None
} else {
- if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
- Some(envelope)
- } else {
+ proto::build_typed_envelope(connection_id, incoming).or_else(|| {
log::error!("unable to construct a typed envelope");
None
- }
+ })
}
}
});
(connection_id, handle_io, incoming_rx.boxed())
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub async fn add_test_connection(
+ self: &Arc<Self>,
+ connection: Connection,
+ executor: Arc<gpui::executor::Background>,
+ ) -> (
+ ConnectionId,
+ impl Future<Output = anyhow::Result<()>> + Send,
+ BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
+ ) {
+ let executor = executor.clone();
+ self.add_connection(connection, move |duration| executor.timer(duration))
+ .await
+ }
+
pub fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().remove(&connection_id);
}
@@ -349,15 +373,21 @@ mod tests {
let (client1_to_server_conn, server_to_client_1_conn, _) =
Connection::in_memory(cx.background());
- let (client1_conn_id, io_task1, client1_incoming) =
- client1.add_connection(client1_to_server_conn).await;
- let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
+ let (client1_conn_id, io_task1, client1_incoming) = client1
+ .add_test_connection(client1_to_server_conn, cx.background())
+ .await;
+ let (_, io_task2, server_incoming1) = server
+ .add_test_connection(server_to_client_1_conn, cx.background())
+ .await;
let (client2_to_server_conn, server_to_client_2_conn, _) =
Connection::in_memory(cx.background());
- let (client2_conn_id, io_task3, client2_incoming) =
- client2.add_connection(client2_to_server_conn).await;
- let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
+ let (client2_conn_id, io_task3, client2_incoming) = client2
+ .add_test_connection(client2_to_server_conn, cx.background())
+ .await;
+ let (_, io_task4, server_incoming2) = server
+ .add_test_connection(server_to_client_2_conn, cx.background())
+ .await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@@ -440,10 +470,12 @@ mod tests {
let (client_to_server_conn, server_to_client_conn, _) =
Connection::in_memory(cx.background());
- let (client_to_server_conn_id, io_task1, mut client_incoming) =
- client.add_connection(client_to_server_conn).await;
- let (server_to_client_conn_id, io_task2, mut server_incoming) =
- server.add_connection(server_to_client_conn).await;
+ let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+ .add_test_connection(client_to_server_conn, cx.background())
+ .await;
+ let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+ .add_test_connection(server_to_client_conn, cx.background())
+ .await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@@ -538,10 +570,12 @@ mod tests {
let (client_to_server_conn, server_to_client_conn, _) =
Connection::in_memory(cx.background());
- let (client_to_server_conn_id, io_task1, mut client_incoming) =
- client.add_connection(client_to_server_conn).await;
- let (server_to_client_conn_id, io_task2, mut server_incoming) =
- server.add_connection(server_to_client_conn).await;
+ let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+ .add_test_connection(client_to_server_conn, cx.background())
+ .await;
+ let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+ .add_test_connection(server_to_client_conn, cx.background())
+ .await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@@ -649,7 +683,9 @@ mod tests {
let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
let client = Peer::new();
- let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+ let (connection_id, io_handler, mut incoming) = client
+ .add_test_connection(client_conn, cx.background())
+ .await;
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
executor
@@ -683,7 +719,9 @@ mod tests {
let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
let client = Peer::new();
- let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+ let (connection_id, io_handler, mut incoming) = client
+ .add_test_connection(client_conn, cx.background())
+ .await;
executor.spawn(io_handler).detach();
executor
.spawn(async move { incoming.next().await })
@@ -318,6 +318,13 @@ where
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
Ok(())
}
+
+ pub async fn ping(&mut self) -> Result<(), WebSocketError> {
+ self.stream
+ .send(WebSocketMessage::Ping(Default::default()))
+ .await?;
+ Ok(())
+ }
}
impl<S> MessageStream<S>
@@ -16,6 +16,7 @@ required-features = ["seed-support"]
collections = { path = "../collections" }
rpc = { path = "../rpc" }
anyhow = "1.0.40"
+async-io = "1.3"
async-std = { version = "1.8.0", features = ["attributes"] }
async-trait = "0.1.50"
async-tungstenite = "0.16"
@@ -6,6 +6,7 @@ use super::{
AppState,
};
use anyhow::anyhow;
+use async_io::Timer;
use async_std::task;
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use collections::{HashMap, HashSet};
@@ -16,7 +17,12 @@ use rpc::{
Connection, ConnectionId, Peer, TypedEnvelope,
};
use sha1::{Digest as _, Sha1};
-use std::{any::TypeId, future::Future, sync::Arc, time::Instant};
+use std::{
+ any::TypeId,
+ future::Future,
+ sync::Arc,
+ time::{Duration, Instant},
+};
use store::{Store, Worktree};
use surf::StatusCode;
use tide::log;
@@ -40,10 +46,13 @@ pub struct Server {
notifications: Option<mpsc::UnboundedSender<()>>,
}
-pub trait Executor {
+pub trait Executor: Send + Clone {
+ type Timer: Send + Future;
fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
+ fn timer(&self, duration: Duration) -> Self::Timer;
}
+#[derive(Clone)]
pub struct RealExecutor;
const MESSAGE_COUNT_PER_PAGE: usize = 100;
@@ -167,8 +176,18 @@ impl Server {
) -> impl Future<Output = ()> {
let mut this = self.clone();
async move {
- let (connection_id, handle_io, mut incoming_rx) =
- this.peer.add_connection(connection).await;
+ let (connection_id, handle_io, mut incoming_rx) = this
+ .peer
+ .add_connection(connection, {
+ let executor = executor.clone();
+ move |duration| {
+ let timer = executor.timer(duration);
+ async move {
+ timer.await;
+ }
+ }
+ })
+ .await;
if let Some(send_connection_id) = send_connection_id.as_mut() {
let _ = send_connection_id.send(connection_id).await;
@@ -883,9 +902,15 @@ impl Server {
}
impl Executor for RealExecutor {
+ type Timer = Timer;
+
fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
task::spawn(future);
}
+
+ fn timer(&self, duration: Duration) -> Self::Timer {
+ Timer::after(duration)
+ }
}
fn broadcast<F>(
@@ -1759,7 +1784,7 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_peer_disconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+ async fn test_leaving_project(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = FakeFs::new(cx_a.background());
@@ -1817,16 +1842,39 @@ mod tests {
.await
.unwrap();
- // See that a guest has joined as client A.
+ // Client A sees that a guest has joined.
project_a
.condition(&cx_a, |p, _| p.collaborators().len() == 1)
.await;
- // Drop client B's connection and ensure client A observes client B leaving the worktree.
+ // Drop client B's connection and ensure client A observes client B leaving the project.
client_b.disconnect(&cx_b.to_async()).unwrap();
project_a
.condition(&cx_a, |p, _| p.collaborators().len() == 0)
.await;
+
+ // Rejoin the project as client B
+ let _project_b = Project::remote(
+ project_id,
+ client_b.clone(),
+ client_b.user_store.clone(),
+ lang_registry.clone(),
+ fs.clone(),
+ &mut cx_b.to_async(),
+ )
+ .await
+ .unwrap();
+
+ // Client A sees that a guest has re-joined.
+ project_a
+ .condition(&cx_a, |p, _| p.collaborators().len() == 1)
+ .await;
+
+ // Simulate connection loss for client B and ensure client A observes client B leaving the project.
+ server.disconnect_client(client_b.current_user_id(cx_b));
+ project_a
+ .condition(&cx_a, |p, _| p.collaborators().len() == 0)
+ .await;
}
#[gpui::test(iterations = 10)]
@@ -5031,9 +5079,15 @@ mod tests {
}
impl Executor for Arc<gpui::executor::Background> {
+ type Timer = BoxFuture<'static, ()>;
+
fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
self.spawn(future).detach();
}
+
+ fn timer(&self, duration: Duration) -> Self::Timer {
+ self.as_ref().timer(duration).boxed()
+ }
}
fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {