Merge pull request #542 from zed-industries/guest-disconnections

Max Brunsfeld created

Send heartbeats in both directions so the server can detect when clients disconnect

Change summary

.zed.toml                         |   2 
Cargo.lock                        |   6 
crates/client/src/client.rs       |  59 ++-------
crates/client/src/test.rs         |   9 +
crates/clock/Cargo.toml           |   1 
crates/clock/src/clock.rs         |  31 -----
crates/gpui/Cargo.toml            |   1 
crates/gpui/src/executor.rs       | 147 +++++++++++++++++---------
crates/gpui/src/util.rs           |   1 
crates/language/src/proto.rs      |  45 +++++--
crates/language/src/tests.rs      |   3 
crates/project/src/lsp_command.rs |  26 ++--
crates/project/src/project.rs     |  26 ++--
crates/project/src/worktree.rs    |  15 +-
crates/rpc/Cargo.toml             |   2 
crates/rpc/src/conn.rs            | 124 ++++++++++-----------
crates/rpc/src/peer.rs            | 182 ++++++++++++++++++++++----------
crates/rpc/src/proto.rs           |  51 ++++++--
crates/server/Cargo.toml          |   1 
crates/server/src/rpc.rs          |  80 +++++++++++--
crates/text/src/network.rs        |  69 ++++++++++++
crates/text/src/tests.rs          |   3 
crates/text/src/text.rs           |   2 
crates/util/Cargo.toml            |   3 
crates/util/src/lib.rs            |  12 ++
crates/util/src/test.rs           |  69 ------------
26 files changed, 571 insertions(+), 399 deletions(-)

Detailed changes

.zed.toml 🔗

@@ -1 +1 @@
-collaborators = ["nathansobo", "as-cii", "maxbrunsfeld", "iamnbutler"]
+collaborators = ["nathansobo", "as-cii", "maxbrunsfeld", "iamnbutler", "Kethku"]

Cargo.lock 🔗

@@ -998,7 +998,6 @@ dependencies = [
 name = "clock"
 version = "0.1.0"
 dependencies = [
- "rpc",
  "smallvec",
 ]
 
@@ -2236,6 +2235,7 @@ dependencies = [
  "tiny-skia",
  "tree-sitter",
  "usvg",
+ "util",
  "waker-fn",
 ]
 
@@ -3959,6 +3959,7 @@ dependencies = [
  "async-lock",
  "async-tungstenite",
  "base64 0.13.0",
+ "clock",
  "futures",
  "gpui",
  "log",
@@ -3972,6 +3973,7 @@ dependencies = [
  "smol",
  "smol-timeout",
  "tempdir",
+ "util",
  "zstd",
 ]
 
@@ -5574,7 +5576,6 @@ name = "util"
 version = "0.1.0"
 dependencies = [
  "anyhow",
- "clock",
  "futures",
  "log",
  "rand 0.8.3",
@@ -5959,6 +5960,7 @@ name = "zed-server"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-io",
  "async-sqlx-session",
  "async-std",
  "async-trait",

crates/client/src/client.rs 🔗

@@ -137,8 +137,8 @@ struct ClientState {
     credentials: Option<Credentials>,
     status: (watch::Sender<Status>, watch::Receiver<Status>),
     entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
-    _maintain_connection: Option<Task<()>>,
-    heartbeat_interval: Duration,
+    _reconnect_task: Option<Task<()>>,
+    reconnect_interval: Duration,
     models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
     models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
     model_types_by_message_type: HashMap<TypeId, TypeId>,
@@ -168,8 +168,8 @@ impl Default for ClientState {
             credentials: None,
             status: watch::channel_with(Status::SignedOut),
             entity_id_extractors: Default::default(),
-            _maintain_connection: None,
-            heartbeat_interval: Duration::from_secs(5),
+            _reconnect_task: None,
+            reconnect_interval: Duration::from_secs(5),
             models_by_message_type: Default::default(),
             models_by_entity_type_and_remote_id: Default::default(),
             model_types_by_message_type: Default::default(),
@@ -236,7 +236,7 @@ impl Client {
     #[cfg(any(test, feature = "test-support"))]
     pub fn tear_down(&self) {
         let mut state = self.state.write();
-        state._maintain_connection.take();
+        state._reconnect_task.take();
         state.message_handlers.clear();
         state.models_by_message_type.clear();
         state.models_by_entity_type_and_remote_id.clear();
@@ -283,21 +283,12 @@ impl Client {
 
         match status {
             Status::Connected { .. } => {
-                let heartbeat_interval = state.heartbeat_interval;
-                let this = self.clone();
-                let foreground = cx.foreground();
-                state._maintain_connection = Some(cx.foreground().spawn(async move {
-                    loop {
-                        foreground.timer(heartbeat_interval).await;
-                        let _ = this.request(proto::Ping {}).await;
-                    }
-                }));
+                state._reconnect_task = None;
             }
             Status::ConnectionLost => {
                 let this = self.clone();
-                let foreground = cx.foreground();
-                let heartbeat_interval = state.heartbeat_interval;
-                state._maintain_connection = Some(cx.spawn(|cx| async move {
+                let reconnect_interval = state.reconnect_interval;
+                state._reconnect_task = Some(cx.spawn(|cx| async move {
                     let mut rng = StdRng::from_entropy();
                     let mut delay = Duration::from_millis(100);
                     while let Err(error) = this.authenticate_and_connect(&cx).await {
@@ -308,15 +299,15 @@ impl Client {
                             },
                             &cx,
                         );
-                        foreground.timer(delay).await;
+                        cx.background().timer(delay).await;
                         delay = delay
                             .mul_f32(rng.gen_range(1.0..=2.0))
-                            .min(heartbeat_interval);
+                            .min(reconnect_interval);
                     }
                 }));
             }
             Status::SignedOut | Status::UpgradeRequired => {
-                state._maintain_connection.take();
+                state._reconnect_task.take();
             }
             _ => {}
         }
@@ -548,7 +539,11 @@ impl Client {
     }
 
     async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
-        let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
+        let executor = cx.background();
+        let (connection_id, handle_io, mut incoming) = self
+            .peer
+            .add_connection(conn, move |duration| executor.timer(duration))
+            .await;
         cx.foreground()
             .spawn({
                 let cx = cx.clone();
@@ -940,26 +935,6 @@ mod tests {
     use crate::test::{FakeHttpClient, FakeServer};
     use gpui::TestAppContext;
 
-    #[gpui::test(iterations = 10)]
-    async fn test_heartbeat(cx: &mut TestAppContext) {
-        cx.foreground().forbid_parking();
-
-        let user_id = 5;
-        let mut client = Client::new(FakeHttpClient::with_404_response());
-        let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
-        cx.foreground().advance_clock(Duration::from_secs(10));
-        let ping = server.receive::<proto::Ping>().await.unwrap();
-        server.respond(ping.receipt(), proto::Ack {}).await;
-
-        cx.foreground().advance_clock(Duration::from_secs(10));
-        let ping = server.receive::<proto::Ping>().await.unwrap();
-        server.respond(ping.receipt(), proto::Ack {}).await;
-
-        client.disconnect(&cx.to_async()).unwrap();
-        assert!(server.receive::<proto::Ping>().await.is_err());
-    }
-
     #[gpui::test(iterations = 10)]
     async fn test_reconnection(cx: &mut TestAppContext) {
         cx.foreground().forbid_parking();
@@ -991,8 +966,6 @@ mod tests {
         server.roll_access_token();
         server.allow_connections();
         cx.foreground().advance_clock(Duration::from_secs(10));
-        assert_eq!(server.auth_count(), 1);
-        cx.foreground().advance_clock(Duration::from_secs(10));
         while !matches!(status.next().await, Some(Status::Connected { .. })) {}
         assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
     }

crates/client/src/test.rs 🔗

@@ -6,6 +6,7 @@ use anyhow::{anyhow, Result};
 use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
 use gpui::{executor, ModelHandle, TestAppContext};
 use parking_lot::Mutex;
+use postage::barrier;
 use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
 use std::{fmt, rc::Rc, sync::Arc};
 
@@ -22,6 +23,7 @@ struct FakeServerState {
     connection_id: Option<ConnectionId>,
     forbid_connections: bool,
     auth_count: usize,
+    connection_killer: Option<barrier::Sender>,
     access_token: usize,
 }
 
@@ -74,12 +76,15 @@ impl FakeServer {
                             Err(EstablishConnectionError::Unauthorized)?
                         }
 
-                        let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
-                        let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+                        let (client_conn, server_conn, kill) =
+                            Connection::in_memory(cx.background());
+                        let (connection_id, io, incoming) =
+                            peer.add_test_connection(server_conn, cx.background()).await;
                         cx.background().spawn(io).detach();
                         let mut state = state.lock();
                         state.connection_id = Some(connection_id);
                         state.incoming = Some(incoming);
+                        state.connection_killer = Some(kill);
                         Ok(client_conn)
                     })
                 }

crates/clock/Cargo.toml 🔗

@@ -9,4 +9,3 @@ doctest = false
 
 [dependencies]
 smallvec = { version = "1.6", features = ["union"] }
-rpc = { path = "../rpc" }

crates/clock/src/clock.rs 🔗

@@ -69,37 +69,6 @@ impl<'a> AddAssign<&'a Local> for Local {
 #[derive(Clone, Default, Hash, Eq, PartialEq)]
 pub struct Global(SmallVec<[u32; 8]>);
 
-impl From<Vec<rpc::proto::VectorClockEntry>> for Global {
-    fn from(message: Vec<rpc::proto::VectorClockEntry>) -> Self {
-        let mut version = Self::new();
-        for entry in message {
-            version.observe(Local {
-                replica_id: entry.replica_id as ReplicaId,
-                value: entry.timestamp,
-            });
-        }
-        version
-    }
-}
-
-impl<'a> From<&'a Global> for Vec<rpc::proto::VectorClockEntry> {
-    fn from(version: &'a Global) -> Self {
-        version
-            .iter()
-            .map(|entry| rpc::proto::VectorClockEntry {
-                replica_id: entry.replica_id as u32,
-                timestamp: entry.value,
-            })
-            .collect()
-    }
-}
-
-impl From<Global> for Vec<rpc::proto::VectorClockEntry> {
-    fn from(version: Global) -> Self {
-        (&version).into()
-    }
-}
-
 impl Global {
     pub fn new() -> Self {
         Self::default()

crates/gpui/Cargo.toml 🔗

@@ -14,6 +14,7 @@ test-support = ["backtrace", "dhat", "env_logger", "collections/test-support"]
 [dependencies]
 collections = { path = "../collections" }
 gpui_macros = { path = "../gpui_macros" }
+util = { path = "../util" }
 sum_tree = { path = "../sum_tree" }
 async-task = "4.0.3"
 backtrace = { version = "0.3", optional = true }

crates/gpui/src/executor.rs 🔗

@@ -1,6 +1,6 @@
 use anyhow::{anyhow, Result};
 use async_task::Runnable;
-use smol::{channel, prelude::*, Executor, Timer};
+use smol::{channel, prelude::*, Executor};
 use std::{
     any::Any,
     fmt::{self, Display},
@@ -86,6 +86,19 @@ pub struct Deterministic {
     parker: parking_lot::Mutex<parking::Parker>,
 }
 
+pub enum Timer {
+    Production(smol::Timer),
+    #[cfg(any(test, feature = "test-support"))]
+    Deterministic(DeterministicTimer),
+}
+
+#[cfg(any(test, feature = "test-support"))]
+pub struct DeterministicTimer {
+    rx: postage::barrier::Receiver,
+    id: usize,
+    state: Arc<parking_lot::Mutex<DeterministicState>>,
+}
+
 #[cfg(any(test, feature = "test-support"))]
 impl Deterministic {
     pub fn new(seed: u64) -> Arc<Self> {
@@ -306,15 +319,82 @@ impl Deterministic {
         None
     }
 
-    pub fn advance_clock(&self, duration: Duration) {
+    pub fn timer(&self, duration: Duration) -> Timer {
+        let (tx, rx) = postage::barrier::channel();
         let mut state = self.state.lock();
-        state.now += duration;
-        let now = state.now;
-        let mut pending_timers = mem::take(&mut state.pending_timers);
-        drop(state);
+        let wakeup_at = state.now + duration;
+        let id = util::post_inc(&mut state.next_timer_id);
+        state.pending_timers.push((id, wakeup_at, tx));
+        let state = self.state.clone();
+        Timer::Deterministic(DeterministicTimer { rx, id, state })
+    }
+
+    pub fn advance_clock(&self, duration: Duration) {
+        let new_now = self.state.lock().now + duration;
+        loop {
+            self.run_until_parked();
+            let mut state = self.state.lock();
 
-        pending_timers.retain(|(_, wakeup, _)| *wakeup > now);
-        self.state.lock().pending_timers.extend(pending_timers);
+            if let Some((_, wakeup_time, _)) = state.pending_timers.first() {
+                let wakeup_time = *wakeup_time;
+                if wakeup_time < new_now {
+                    let timer_count = state
+                        .pending_timers
+                        .iter()
+                        .take_while(|(_, t, _)| *t == wakeup_time)
+                        .count();
+                    state.now = wakeup_time;
+                    let timers_to_wake = state
+                        .pending_timers
+                        .drain(0..timer_count)
+                        .collect::<Vec<_>>();
+                    drop(state);
+                    drop(timers_to_wake);
+                    continue;
+                }
+            }
+
+            break;
+        }
+
+        self.state.lock().now = new_now;
+    }
+}
+
+impl Drop for Timer {
+    fn drop(&mut self) {
+        #[cfg(any(test, feature = "test-support"))]
+        if let Timer::Deterministic(DeterministicTimer { state, id, .. }) = self {
+            state
+                .lock()
+                .pending_timers
+                .retain(|(timer_id, _, _)| timer_id != id)
+        }
+    }
+}
+
+impl Future for Timer {
+    type Output = ();
+
+    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+        match &mut *self {
+            #[cfg(any(test, feature = "test-support"))]
+            Self::Deterministic(DeterministicTimer { rx, .. }) => {
+                use postage::stream::{PollRecv, Stream as _};
+                smol::pin!(rx);
+                match rx.poll_recv(&mut postage::Context::from_waker(cx.waker())) {
+                    PollRecv::Ready(()) | PollRecv::Closed => Poll::Ready(()),
+                    PollRecv::Pending => Poll::Pending,
+                }
+            }
+            Self::Production(timer) => {
+                smol::pin!(timer);
+                match timer.poll(cx) {
+                    Poll::Ready(_) => Poll::Ready(()),
+                    Poll::Pending => Poll::Pending,
+                }
+            }
+        }
     }
 }
 
@@ -438,46 +518,6 @@ impl Foreground {
         }
     }
 
-    pub async fn timer(&self, duration: Duration) {
-        match self {
-            #[cfg(any(test, feature = "test-support"))]
-            Self::Deterministic { executor, .. } => {
-                use postage::prelude::Stream as _;
-
-                let (tx, mut rx) = postage::barrier::channel();
-                let timer_id;
-                {
-                    let mut state = executor.state.lock();
-                    let wakeup_at = state.now + duration;
-                    timer_id = util::post_inc(&mut state.next_timer_id);
-                    state.pending_timers.push((timer_id, wakeup_at, tx));
-                }
-
-                struct DropTimer<'a>(usize, &'a Foreground);
-                impl<'a> Drop for DropTimer<'a> {
-                    fn drop(&mut self) {
-                        match self.1 {
-                            Foreground::Deterministic { executor, .. } => {
-                                executor
-                                    .state
-                                    .lock()
-                                    .pending_timers
-                                    .retain(|(timer_id, _, _)| *timer_id != self.0);
-                            }
-                            _ => unreachable!(),
-                        }
-                    }
-                }
-
-                let _guard = DropTimer(timer_id, self);
-                rx.recv().await;
-            }
-            _ => {
-                Timer::after(duration).await;
-            }
-        }
-    }
-
     #[cfg(any(test, feature = "test-support"))]
     pub fn advance_clock(&self, duration: Duration) {
         match self {
@@ -600,6 +640,14 @@ impl Background {
         }
     }
 
+    pub fn timer(&self, duration: Duration) -> Timer {
+        match self {
+            Background::Production { .. } => Timer::Production(smol::Timer::after(duration)),
+            #[cfg(any(test, feature = "test-support"))]
+            Background::Deterministic { executor } => executor.timer(duration),
+        }
+    }
+
     #[cfg(any(test, feature = "test-support"))]
     pub async fn simulate_random_delay(&self) {
         use rand::prelude::*;
@@ -612,9 +660,6 @@ impl Background {
                     for _ in 0..yields {
                         yield_now().await;
                     }
-
-                    let delay = Duration::from_millis(executor.state.lock().rng.gen_range(0..100));
-                    executor.advance_clock(delay);
                 }
             }
             _ => panic!("this method can only be called on a deterministic executor"),

crates/gpui/src/util.rs 🔗

@@ -1,5 +1,6 @@
 use smol::future::FutureExt;
 use std::{future::Future, time::Duration};
+pub use util::*;
 
 pub fn post_inc(value: &mut usize) -> usize {
     let prev = *value;

crates/language/src/proto.rs 🔗

@@ -25,13 +25,13 @@ pub fn serialize_operation(operation: &Operation) -> proto::Operation {
                 replica_id: undo.id.replica_id as u32,
                 local_timestamp: undo.id.value,
                 lamport_timestamp: lamport_timestamp.value,
-                version: From::from(&undo.version),
+                version: serialize_version(&undo.version),
                 transaction_ranges: undo
                     .transaction_ranges
                     .iter()
                     .map(serialize_range)
                     .collect(),
-                transaction_version: From::from(&undo.transaction_version),
+                transaction_version: serialize_version(&undo.transaction_version),
                 counts: undo
                     .counts
                     .iter()
@@ -77,7 +77,7 @@ pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::
         replica_id: operation.timestamp.replica_id as u32,
         local_timestamp: operation.timestamp.local,
         lamport_timestamp: operation.timestamp.lamport,
-        version: From::from(&operation.version),
+        version: serialize_version(&operation.version),
         ranges: operation.ranges.iter().map(serialize_range).collect(),
         new_text: operation.new_text.clone(),
     }
@@ -116,7 +116,7 @@ pub fn serialize_buffer_fragment(fragment: &text::Fragment) -> proto::BufferFrag
                 timestamp: clock.value,
             })
             .collect(),
-        max_undos: From::from(&fragment.max_undos),
+        max_undos: serialize_version(&fragment.max_undos),
     }
 }
 
@@ -188,7 +188,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<Operation> {
                         replica_id: undo.replica_id as ReplicaId,
                         value: undo.local_timestamp,
                     },
-                    version: undo.version.into(),
+                    version: deserialize_version(undo.version),
                     counts: undo
                         .counts
                         .into_iter()
@@ -207,7 +207,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<Operation> {
                         .into_iter()
                         .map(deserialize_range)
                         .collect(),
-                    transaction_version: undo.transaction_version.into(),
+                    transaction_version: deserialize_version(undo.transaction_version),
                 },
             }),
             proto::operation::Variant::UpdateSelections(message) => {
@@ -260,7 +260,7 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation
             local: edit.local_timestamp,
             lamport: edit.lamport_timestamp,
         },
-        version: edit.version.into(),
+        version: deserialize_version(edit.version),
         ranges: edit.ranges.into_iter().map(deserialize_range).collect(),
         new_text: edit.new_text,
     }
@@ -309,7 +309,7 @@ pub fn deserialize_buffer_fragment(
             replica_id: entry.replica_id as ReplicaId,
             value: entry.timestamp,
         })),
-        max_undos: From::from(message.max_undos),
+        max_undos: deserialize_version(message.max_undos),
     }
 }
 
@@ -472,8 +472,8 @@ pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction {
             .copied()
             .map(serialize_local_timestamp)
             .collect(),
-        start: (&transaction.start).into(),
-        end: (&transaction.end).into(),
+        start: serialize_version(&transaction.start),
+        end: serialize_version(&transaction.end),
         ranges: transaction.ranges.iter().map(serialize_range).collect(),
     }
 }
@@ -490,8 +490,8 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transa
             .into_iter()
             .map(deserialize_local_timestamp)
             .collect(),
-        start: transaction.start.into(),
-        end: transaction.end.into(),
+        start: deserialize_version(transaction.start.into()),
+        end: deserialize_version(transaction.end),
         ranges: transaction
             .ranges
             .into_iter()
@@ -524,3 +524,24 @@ pub fn serialize_range(range: &Range<FullOffset>) -> proto::Range {
 pub fn deserialize_range(range: proto::Range) -> Range<FullOffset> {
     FullOffset(range.start as usize)..FullOffset(range.end as usize)
 }
+
+pub fn deserialize_version(message: Vec<proto::VectorClockEntry>) -> clock::Global {
+    let mut version = clock::Global::new();
+    for entry in message {
+        version.observe(clock::Local {
+            replica_id: entry.replica_id as ReplicaId,
+            value: entry.timestamp,
+        });
+    }
+    version
+}
+
+pub fn serialize_version(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
+    version
+        .iter()
+        .map(|entry| proto::VectorClockEntry {
+            replica_id: entry.replica_id as u32,
+            timestamp: entry.value,
+        })
+        .collect()
+}

crates/language/src/tests.rs 🔗

@@ -11,8 +11,9 @@ use std::{
     rc::Rc,
     time::{Duration, Instant},
 };
+use text::network::Network;
 use unindent::Unindent as _;
-use util::{post_inc, test::Network};
+use util::post_inc;
 
 #[cfg(test)]
 #[ctor::ctor]

crates/project/src/lsp_command.rs 🔗

@@ -5,7 +5,7 @@ use client::{proto, PeerId};
 use gpui::{AppContext, AsyncAppContext, ModelHandle};
 use language::{
     point_from_lsp,
-    proto::{deserialize_anchor, serialize_anchor},
+    proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version},
     range_from_lsp, Anchor, Bias, Buffer, PointUtf16, ToLspPosition, ToPointUtf16,
 };
 use lsp::{DocumentHighlightKind, ServerCapabilities};
@@ -126,7 +126,7 @@ impl LspCommand for PrepareRename {
             position: Some(language::proto::serialize_anchor(
                 &buffer.anchor_before(self.position),
             )),
-            version: (&buffer.version()).into(),
+            version: serialize_version(&buffer.version()),
         }
     }
 
@@ -142,7 +142,7 @@ impl LspCommand for PrepareRename {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(message.version.into())
+                buffer.wait_for_version(deserialize_version(message.version))
             })
             .await;
 
@@ -166,7 +166,7 @@ impl LspCommand for PrepareRename {
             end: range
                 .as_ref()
                 .map(|range| language::proto::serialize_anchor(&range.end)),
-            version: buffer_version.into(),
+            version: serialize_version(buffer_version),
         }
     }
 
@@ -180,7 +180,7 @@ impl LspCommand for PrepareRename {
         if message.can_rename {
             buffer
                 .update(&mut cx, |buffer, _| {
-                    buffer.wait_for_version(message.version.into())
+                    buffer.wait_for_version(deserialize_version(message.version))
                 })
                 .await;
             let start = message.start.and_then(deserialize_anchor);
@@ -255,7 +255,7 @@ impl LspCommand for PerformRename {
                 &buffer.anchor_before(self.position),
             )),
             new_name: self.new_name.clone(),
-            version: (&buffer.version()).into(),
+            version: serialize_version(&buffer.version()),
         }
     }
 
@@ -271,7 +271,7 @@ impl LspCommand for PerformRename {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(message.version.into())
+                buffer.wait_for_version(deserialize_version(message.version))
             })
             .await;
         Ok(Self {
@@ -407,7 +407,7 @@ impl LspCommand for GetDefinition {
             position: Some(language::proto::serialize_anchor(
                 &buffer.anchor_before(self.position),
             )),
-            version: (&buffer.version()).into(),
+            version: serialize_version(&buffer.version()),
         }
     }
 
@@ -423,7 +423,7 @@ impl LspCommand for GetDefinition {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(message.version.into())
+                buffer.wait_for_version(deserialize_version(message.version))
             })
             .await;
         Ok(Self {
@@ -566,7 +566,7 @@ impl LspCommand for GetReferences {
             position: Some(language::proto::serialize_anchor(
                 &buffer.anchor_before(self.position),
             )),
-            version: (&buffer.version()).into(),
+            version: serialize_version(&buffer.version()),
         }
     }
 
@@ -582,7 +582,7 @@ impl LspCommand for GetReferences {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(message.version.into())
+                buffer.wait_for_version(deserialize_version(message.version))
             })
             .await;
         Ok(Self {
@@ -706,7 +706,7 @@ impl LspCommand for GetDocumentHighlights {
             position: Some(language::proto::serialize_anchor(
                 &buffer.anchor_before(self.position),
             )),
-            version: (&buffer.version()).into(),
+            version: serialize_version(&buffer.version()),
         }
     }
 
@@ -722,7 +722,7 @@ impl LspCommand for GetDocumentHighlights {
             .ok_or_else(|| anyhow!("invalid position"))?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(message.version.into())
+                buffer.wait_for_version(deserialize_version(message.version))
             })
             .await;
         Ok(Self {

crates/project/src/project.rs 🔗

@@ -15,7 +15,7 @@ use gpui::{
     UpgradeModelHandle, WeakModelHandle,
 };
 use language::{
-    proto::{deserialize_anchor, serialize_anchor},
+    proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version},
     range_from_lsp, Anchor, AnchorRangeExt, Bias, Buffer, CodeAction, CodeLabel, Completion,
     Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, Operation, PointUtf16,
     ToLspPosition, ToOffset, ToPointUtf16, Transaction,
@@ -1713,14 +1713,14 @@ impl Project {
                 project_id,
                 buffer_id,
                 position: Some(language::proto::serialize_anchor(&anchor)),
-                version: (&source_buffer.version()).into(),
+                version: serialize_version(&source_buffer.version()),
             };
             cx.spawn_weak(|_, mut cx| async move {
                 let response = rpc.request(message).await?;
 
                 source_buffer_handle
                     .update(&mut cx, |buffer, _| {
-                        buffer.wait_for_version(response.version.into())
+                        buffer.wait_for_version(deserialize_version(response.version))
                     })
                     .await;
 
@@ -1910,13 +1910,13 @@ impl Project {
                         buffer_id,
                         start: Some(language::proto::serialize_anchor(&range.start)),
                         end: Some(language::proto::serialize_anchor(&range.end)),
-                        version: (&version).into(),
+                        version: serialize_version(&version),
                     })
                     .await?;
 
                 buffer_handle
                     .update(&mut cx, |buffer, _| {
-                        buffer.wait_for_version(response.version.into())
+                        buffer.wait_for_version(deserialize_version(response.version))
                     })
                     .await;
 
@@ -2915,7 +2915,7 @@ impl Project {
         mut cx: AsyncAppContext,
     ) -> Result<proto::BufferSaved> {
         let buffer_id = envelope.payload.buffer_id;
-        let requested_version = envelope.payload.version.try_into()?;
+        let requested_version = deserialize_version(envelope.payload.version);
 
         let (project_id, buffer) = this.update(&mut cx, |this, cx| {
             let project_id = this.remote_id().ok_or_else(|| anyhow!("not connected"))?;
@@ -2936,7 +2936,7 @@ impl Project {
         Ok(proto::BufferSaved {
             project_id,
             buffer_id,
-            version: (&saved_version).into(),
+            version: serialize_version(&saved_version),
             mtime: Some(mtime.into()),
         })
     }
@@ -2981,7 +2981,7 @@ impl Project {
             .position
             .and_then(language::proto::deserialize_anchor)
             .ok_or_else(|| anyhow!("invalid position"))?;
-        let version = clock::Global::from(envelope.payload.version);
+        let version = deserialize_version(envelope.payload.version);
         let buffer = this.read_with(&cx, |this, cx| {
             this.opened_buffers
                 .get(&envelope.payload.buffer_id)
@@ -3001,7 +3001,7 @@ impl Project {
                 .iter()
                 .map(language::proto::serialize_completion)
                 .collect(),
-            version: (&version).into(),
+            version: serialize_version(&version),
         })
     }
 
@@ -3062,7 +3062,7 @@ impl Project {
         })?;
         buffer
             .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(envelope.payload.version.into())
+                buffer.wait_for_version(deserialize_version(envelope.payload.version))
             })
             .await;
 
@@ -3077,7 +3077,7 @@ impl Project {
                 .iter()
                 .map(language::proto::serialize_code_action)
                 .collect(),
-            version: (&version).into(),
+            version: serialize_version(&version),
         })
     }
 
@@ -3445,7 +3445,7 @@ impl Project {
         _: Arc<Client>,
         mut cx: AsyncAppContext,
     ) -> Result<()> {
-        let version = envelope.payload.version.try_into()?;
+        let version = deserialize_version(envelope.payload.version);
         let mtime = envelope
             .payload
             .mtime
@@ -3473,7 +3473,7 @@ impl Project {
         mut cx: AsyncAppContext,
     ) -> Result<()> {
         let payload = envelope.payload.clone();
-        let version = payload.version.try_into()?;
+        let version = deserialize_version(payload.version);
         let mtime = payload
             .mtime
             .ok_or_else(|| anyhow!("missing mtime"))?

crates/project/src/worktree.rs 🔗

@@ -17,7 +17,10 @@ use gpui::{
     executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
     Task,
 };
-use language::{Buffer, DiagnosticEntry, Operation, PointUtf16, Rope};
+use language::{
+    proto::{deserialize_version, serialize_version},
+    Buffer, DiagnosticEntry, Operation, PointUtf16, Rope,
+};
 use lazy_static::lazy_static;
 use parking_lot::Mutex;
 use postage::{
@@ -30,7 +33,7 @@ use smol::channel::{self, Sender};
 use std::{
     any::Any,
     cmp::{self, Ordering},
-    convert::{TryFrom, TryInto},
+    convert::TryFrom,
     ffi::{OsStr, OsString},
     fmt,
     future::Future,
@@ -1423,7 +1426,7 @@ impl language::File for File {
                         rpc.send(proto::BufferSaved {
                             project_id,
                             buffer_id,
-                            version: (&version).into(),
+                            version: serialize_version(&version),
                             mtime: Some(entry.mtime.into()),
                         })?;
                     }
@@ -1438,10 +1441,10 @@ impl language::File for File {
                         .request(proto::SaveBuffer {
                             project_id,
                             buffer_id,
-                            version: (&version).into(),
+                            version: serialize_version(&version),
                         })
                         .await?;
-                    let version = response.version.try_into()?;
+                    let version = deserialize_version(response.version);
                     let mtime = response
                         .mtime
                         .ok_or_else(|| anyhow!("missing mtime"))?
@@ -1518,7 +1521,7 @@ impl language::LocalFile for File {
                 .send(proto::BufferReloaded {
                     project_id,
                     buffer_id,
-                    version: version.into(),
+                    version: serialize_version(&version),
                     mtime: Some(mtime.into()),
                 })
                 .log_err();

crates/rpc/Cargo.toml 🔗

@@ -26,7 +26,9 @@ rsa = "0.4"
 serde = { version = "1", features = ["derive"] }
 smol-timeout = "0.6"
 zstd = "0.9"
+clock = { path = "../clock" }
 gpui = { path = "../gpui", optional = true }
+util = { path = "../util" }
 
 [build-dependencies]
 prost-build = "0.8"

crates/rpc/src/conn.rs 🔗

@@ -1,6 +1,5 @@
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{SinkExt as _, Stream, StreamExt as _};
-use std::{io, task::Poll};
+use futures::{SinkExt as _, StreamExt as _};
 
 pub struct Connection {
     pub(crate) tx:
@@ -36,87 +35,82 @@ impl Connection {
     #[cfg(any(test, feature = "test-support"))]
     pub fn in_memory(
         executor: std::sync::Arc<gpui::executor::Background>,
-    ) -> (Self, Self, postage::watch::Sender<Option<()>>) {
-        let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
-        postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
+    ) -> (Self, Self, postage::barrier::Sender) {
+        use postage::prelude::Stream;
 
-        let (a_tx, a_rx) = Self::channel(kill_rx.clone(), executor.clone());
-        let (b_tx, b_rx) = Self::channel(kill_rx, executor);
-        (
+        let (kill_tx, kill_rx) = postage::barrier::channel();
+        let (a_tx, a_rx) = channel(kill_rx.clone(), executor.clone());
+        let (b_tx, b_rx) = channel(kill_rx, executor);
+        return (
             Self { tx: a_tx, rx: b_rx },
             Self { tx: b_tx, rx: a_rx },
             kill_tx,
-        )
-    }
+        );
 
-    #[cfg(any(test, feature = "test-support"))]
-    fn channel(
-        kill_rx: postage::watch::Receiver<Option<()>>,
-        executor: std::sync::Arc<gpui::executor::Background>,
-    ) -> (
-        Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
-        Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
-    ) {
-        use futures::channel::mpsc;
-        use io::{Error, ErrorKind};
-        use std::sync::Arc;
+        fn channel(
+            kill_rx: postage::barrier::Receiver,
+            executor: std::sync::Arc<gpui::executor::Background>,
+        ) -> (
+            Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+            Box<
+                dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+            >,
+        ) {
+            use futures::channel::mpsc;
+            use std::{
+                io::{Error, ErrorKind},
+                sync::Arc,
+            };
 
-        let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
-        let tx = tx
-            .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
-            .with({
-                let executor = Arc::downgrade(&executor);
+            let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
+
+            let tx = tx
+                .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+                .with({
+                    let kill_rx = kill_rx.clone();
+                    let executor = Arc::downgrade(&executor);
+                    move |msg| {
+                        let mut kill_rx = kill_rx.clone();
+                        let executor = executor.clone();
+                        Box::pin(async move {
+                            if let Some(executor) = executor.upgrade() {
+                                executor.simulate_random_delay().await;
+                            }
+
+                            // Writes to a half-open TCP connection will error.
+                            if kill_rx.try_recv().is_ok() {
+                                std::io::Result::Err(
+                                    Error::new(ErrorKind::Other, "connection lost").into(),
+                                )?;
+                            }
+
+                            Ok(msg)
+                        })
+                    }
+                });
+
+            let rx = rx.then({
                 let kill_rx = kill_rx.clone();
+                let executor = Arc::downgrade(&executor);
                 move |msg| {
-                    let kill_rx = kill_rx.clone();
+                    let mut kill_rx = kill_rx.clone();
                     let executor = executor.clone();
                     Box::pin(async move {
                         if let Some(executor) = executor.upgrade() {
                             executor.simulate_random_delay().await;
                         }
-                        if kill_rx.borrow().is_none() {
-                            Ok(msg)
-                        } else {
-                            Err(Error::new(ErrorKind::Other, "connection killed").into())
+
+                        // Reads from a half-open TCP connection will hang.
+                        if kill_rx.try_recv().is_ok() {
+                            futures::future::pending::<()>().await;
                         }
+
+                        Ok(msg)
                     })
                 }
             });
-        let rx = rx.then(move |msg| {
-            let executor = Arc::downgrade(&executor);
-            Box::pin(async move {
-                if let Some(executor) = executor.upgrade() {
-                    executor.simulate_random_delay().await;
-                }
-                msg
-            })
-        });
-        let rx = KillableReceiver { kill_rx, rx };
-
-        (Box::new(tx), Box::new(rx))
-    }
-}
-
-struct KillableReceiver<S> {
-    rx: S,
-    kill_rx: postage::watch::Receiver<Option<()>>,
-}
-
-impl<S: Unpin + Stream<Item = WebSocketMessage>> Stream for KillableReceiver<S> {
-    type Item = Result<WebSocketMessage, WebSocketError>;
 
-    fn poll_next(
-        mut self: std::pin::Pin<&mut Self>,
-        cx: &mut std::task::Context<'_>,
-    ) -> Poll<Option<Self::Item>> {
-        if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
-            Poll::Ready(Some(Err(io::Error::new(
-                io::ErrorKind::Other,
-                "connection killed",
-            )
-            .into())))
-        } else {
-            self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+            (Box::new(tx), Box::new(rx))
         }
     }
 }

crates/rpc/src/peer.rs 🔗

@@ -88,13 +88,14 @@ pub struct Peer {
 
 #[derive(Clone)]
 pub struct ConnectionState {
-    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
+    outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Message>,
     next_message_id: Arc<AtomicU32>,
     response_channels:
         Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
 }
 
-const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
+const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
+const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
 
 impl Peer {
     pub fn new() -> Arc<Self> {
@@ -104,14 +105,20 @@ impl Peer {
         })
     }
 
-    pub async fn add_connection(
+    pub async fn add_connection<F, Fut, Out>(
         self: &Arc<Self>,
         connection: Connection,
+        create_timer: F,
     ) -> (
         ConnectionId,
         impl Future<Output = anyhow::Result<()>> + Send,
         BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
-    ) {
+    )
+    where
+        F: Send + Fn(Duration) -> Fut,
+        Fut: Send + Future<Output = Out>,
+        Out: Send,
+    {
         // For outgoing messages, use an unbounded channel so that application code
         // can always send messages without yielding. For incoming messages, use a
         // bounded channel so that other peers will receive backpressure if they send
@@ -121,7 +128,7 @@ impl Peer {
 
         let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
         let connection_state = ConnectionState {
-            outgoing_tx,
+            outgoing_tx: outgoing_tx.clone(),
             next_message_id: Default::default(),
             response_channels: Arc::new(Mutex::new(Some(Default::default()))),
         };
@@ -131,39 +138,59 @@ impl Peer {
         let this = self.clone();
         let response_channels = connection_state.response_channels.clone();
         let handle_io = async move {
-            let result = 'outer: loop {
-                let read_message = reader.read_message().fuse();
+            let _end_connection = util::defer(|| {
+                response_channels.lock().take();
+                this.connections.write().remove(&connection_id);
+            });
+
+            // Send messages on this frequency so the connection isn't closed.
+            let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
+            futures::pin_mut!(keepalive_timer);
+
+            loop {
+                let read_message = reader.read().fuse();
                 futures::pin_mut!(read_message);
+
+                // Disconnect if we don't receive messages at least this frequently.
+                let receive_timeout = create_timer(3 * KEEPALIVE_INTERVAL).fuse();
+                futures::pin_mut!(receive_timeout);
+
                 loop {
                     futures::select_biased! {
                         outgoing = outgoing_rx.next().fuse() => match outgoing {
                             Some(outgoing) => {
-                                match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
-                                    None => break 'outer Err(anyhow!("timed out writing RPC message")),
-                                    Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
-                                    _ => {}
+                                if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
+                                    result.context("failed to write RPC message")?;
+                                    keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
+                                } else {
+                                    Err(anyhow!("timed out writing message"))?;
                                 }
                             }
-                            None => break 'outer Ok(()),
+                            None => return Ok(()),
                         },
-                        incoming = read_message => match incoming {
-                            Ok(incoming) => {
+                        incoming = read_message => {
+                            let incoming = incoming.context("received invalid RPC message")?;
+                            if let proto::Message::Envelope(incoming) = incoming {
                                 if incoming_tx.send(incoming).await.is_err() {
-                                    break 'outer Ok(());
+                                    return Ok(());
                                 }
-                                break;
-                            }
-                            Err(error) => {
-                                break 'outer Err(error).context("received invalid RPC message")
                             }
+                            break;
                         },
+                        _ = keepalive_timer => {
+                            if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
+                                result.context("failed to send keepalive")?;
+                                keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
+                            } else {
+                                Err(anyhow!("timed out sending keepalive"))?;
+                            }
+                        }
+                        _ = receive_timeout => {
+                            Err(anyhow!("delay between messages too long"))?
+                        }
                     }
                 }
-            };
-
-            response_channels.lock().take();
-            this.connections.write().remove(&connection_id);
-            result
+            }
         };
 
         let response_channels = connection_state.response_channels.clone();
@@ -191,18 +218,31 @@ impl Peer {
 
                     None
                 } else {
-                    if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
-                        Some(envelope)
-                    } else {
+                    proto::build_typed_envelope(connection_id, incoming).or_else(|| {
                         log::error!("unable to construct a typed envelope");
                         None
-                    }
+                    })
                 }
             }
         });
         (connection_id, handle_io, incoming_rx.boxed())
     }
 
+    #[cfg(any(test, feature = "test-support"))]
+    pub async fn add_test_connection(
+        self: &Arc<Self>,
+        connection: Connection,
+        executor: Arc<gpui::executor::Background>,
+    ) -> (
+        ConnectionId,
+        impl Future<Output = anyhow::Result<()>> + Send,
+        BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
+    ) {
+        let executor = executor.clone();
+        self.add_connection(connection, move |duration| executor.timer(duration))
+            .await
+    }
+
     pub fn disconnect(&self, connection_id: ConnectionId) {
         self.connections.write().remove(&connection_id);
     }
@@ -245,11 +285,11 @@ impl Peer {
                 .insert(message_id, tx);
             connection
                 .outgoing_tx
-                .unbounded_send(request.into_envelope(
+                .unbounded_send(proto::Message::Envelope(request.into_envelope(
                     message_id,
                     None,
                     original_sender_id.map(|id| id.0),
-                ))
+                )))
                 .map_err(|_| anyhow!("connection was closed"))?;
             Ok(())
         });
@@ -272,7 +312,9 @@ impl Peer {
             .fetch_add(1, atomic::Ordering::SeqCst);
         connection
             .outgoing_tx
-            .unbounded_send(message.into_envelope(message_id, None, None))?;
+            .unbounded_send(proto::Message::Envelope(
+                message.into_envelope(message_id, None, None),
+            ))?;
         Ok(())
     }
 
@@ -288,7 +330,11 @@ impl Peer {
             .fetch_add(1, atomic::Ordering::SeqCst);
         connection
             .outgoing_tx
-            .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
+            .unbounded_send(proto::Message::Envelope(message.into_envelope(
+                message_id,
+                None,
+                Some(sender_id.0),
+            )))?;
         Ok(())
     }
 
@@ -303,7 +349,11 @@ impl Peer {
             .fetch_add(1, atomic::Ordering::SeqCst);
         connection
             .outgoing_tx
-            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
+            .unbounded_send(proto::Message::Envelope(response.into_envelope(
+                message_id,
+                Some(receipt.message_id),
+                None,
+            )))?;
         Ok(())
     }
 
@@ -318,7 +368,11 @@ impl Peer {
             .fetch_add(1, atomic::Ordering::SeqCst);
         connection
             .outgoing_tx
-            .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
+            .unbounded_send(proto::Message::Envelope(response.into_envelope(
+                message_id,
+                Some(receipt.message_id),
+                None,
+            )))?;
         Ok(())
     }
 
@@ -347,17 +401,23 @@ mod tests {
         let client1 = Peer::new();
         let client2 = Peer::new();
 
-        let (client1_to_server_conn, server_to_client_1_conn, _) =
+        let (client1_to_server_conn, server_to_client_1_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client1_conn_id, io_task1, client1_incoming) =
-            client1.add_connection(client1_to_server_conn).await;
-        let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
-
-        let (client2_to_server_conn, server_to_client_2_conn, _) =
+        let (client1_conn_id, io_task1, client1_incoming) = client1
+            .add_test_connection(client1_to_server_conn, cx.background())
+            .await;
+        let (_, io_task2, server_incoming1) = server
+            .add_test_connection(server_to_client_1_conn, cx.background())
+            .await;
+
+        let (client2_to_server_conn, server_to_client_2_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client2_conn_id, io_task3, client2_incoming) =
-            client2.add_connection(client2_to_server_conn).await;
-        let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
+        let (client2_conn_id, io_task3, client2_incoming) = client2
+            .add_test_connection(client2_to_server_conn, cx.background())
+            .await;
+        let (_, io_task4, server_incoming2) = server
+            .add_test_connection(server_to_client_2_conn, cx.background())
+            .await;
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -438,12 +498,14 @@ mod tests {
         let server = Peer::new();
         let client = Peer::new();
 
-        let (client_to_server_conn, server_to_client_conn, _) =
+        let (client_to_server_conn, server_to_client_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client_to_server_conn_id, io_task1, mut client_incoming) =
-            client.add_connection(client_to_server_conn).await;
-        let (server_to_client_conn_id, io_task2, mut server_incoming) =
-            server.add_connection(server_to_client_conn).await;
+        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+            .add_test_connection(client_to_server_conn, cx.background())
+            .await;
+        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+            .add_test_connection(server_to_client_conn, cx.background())
+            .await;
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -536,12 +598,14 @@ mod tests {
         let server = Peer::new();
         let client = Peer::new();
 
-        let (client_to_server_conn, server_to_client_conn, _) =
+        let (client_to_server_conn, server_to_client_conn, _kill) =
             Connection::in_memory(cx.background());
-        let (client_to_server_conn_id, io_task1, mut client_incoming) =
-            client.add_connection(client_to_server_conn).await;
-        let (server_to_client_conn_id, io_task2, mut server_incoming) =
-            server.add_connection(server_to_client_conn).await;
+        let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+            .add_test_connection(client_to_server_conn, cx.background())
+            .await;
+        let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+            .add_test_connection(server_to_client_conn, cx.background())
+            .await;
 
         executor.spawn(io_task1).detach();
         executor.spawn(io_task2).detach();
@@ -646,10 +710,12 @@ mod tests {
     async fn test_disconnect(cx: &mut TestAppContext) {
         let executor = cx.foreground();
 
-        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
+        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
-        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+        let (connection_id, io_handler, mut incoming) = client
+            .add_test_connection(client_conn, cx.background())
+            .await;
 
         let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
         executor
@@ -680,10 +746,12 @@ mod tests {
     #[gpui::test(iterations = 50)]
     async fn test_io_error(cx: &mut TestAppContext) {
         let executor = cx.foreground();
-        let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
+        let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
 
         let client = Peer::new();
-        let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+        let (connection_id, io_handler, mut incoming) = client
+            .add_test_connection(client_conn, cx.background())
+            .await;
         executor.spawn(io_handler).detach();
         executor
             .spawn(async move { incoming.next().await })

crates/rpc/src/proto.rs 🔗

@@ -2,7 +2,7 @@ use super::{ConnectionId, PeerId, TypedEnvelope};
 use anyhow::Result;
 use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
 use futures::{SinkExt as _, StreamExt as _};
-use prost::Message;
+use prost::Message as _;
 use std::any::{Any, TypeId};
 use std::{
     io,
@@ -283,6 +283,13 @@ pub struct MessageStream<S> {
     encoding_buffer: Vec<u8>,
 }
 
+#[derive(Debug)]
+pub enum Message {
+    Envelope(Envelope),
+    Ping,
+    Pong,
+}
+
 impl<S> MessageStream<S> {
     pub fn new(stream: S) -> Self {
         Self {
@@ -300,22 +307,37 @@ impl<S> MessageStream<S>
 where
     S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
 {
-    /// Write a given protobuf message to the stream.
-    pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
+    pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
         #[cfg(any(test, feature = "test-support"))]
         const COMPRESSION_LEVEL: i32 = -7;
 
         #[cfg(not(any(test, feature = "test-support")))]
         const COMPRESSION_LEVEL: i32 = 4;
 
-        self.encoding_buffer.resize(message.encoded_len(), 0);
-        self.encoding_buffer.clear();
-        message
-            .encode(&mut self.encoding_buffer)
-            .map_err(|err| io::Error::from(err))?;
-        let buffer =
-            zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
-        self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+        match message {
+            Message::Envelope(message) => {
+                self.encoding_buffer.resize(message.encoded_len(), 0);
+                self.encoding_buffer.clear();
+                message
+                    .encode(&mut self.encoding_buffer)
+                    .map_err(|err| io::Error::from(err))?;
+                let buffer =
+                    zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
+                        .unwrap();
+                self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+            }
+            Message::Ping => {
+                self.stream
+                    .send(WebSocketMessage::Ping(Default::default()))
+                    .await?;
+            }
+            Message::Pong => {
+                self.stream
+                    .send(WebSocketMessage::Pong(Default::default()))
+                    .await?;
+            }
+        }
+
         Ok(())
     }
 }
@@ -324,8 +346,7 @@ impl<S> MessageStream<S>
 where
     S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
 {
-    /// Read a protobuf message of the given type from the stream.
-    pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
+    pub async fn read(&mut self) -> Result<Message, WebSocketError> {
         while let Some(bytes) = self.stream.next().await {
             match bytes? {
                 WebSocketMessage::Binary(bytes) => {
@@ -333,8 +354,10 @@ where
                     zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
                     let envelope = Envelope::decode(self.encoding_buffer.as_slice())
                         .map_err(io::Error::from)?;
-                    return Ok(envelope);
+                    return Ok(Message::Envelope(envelope));
                 }
+                WebSocketMessage::Ping(_) => return Ok(Message::Ping),
+                WebSocketMessage::Pong(_) => return Ok(Message::Pong),
                 WebSocketMessage::Close(_) => break,
                 _ => {}
             }

crates/server/Cargo.toml 🔗

@@ -16,6 +16,7 @@ required-features = ["seed-support"]
 collections = { path = "../collections" }
 rpc = { path = "../rpc" }
 anyhow = "1.0.40"
+async-io = "1.3"
 async-std = { version = "1.8.0", features = ["attributes"] }
 async-trait = "0.1.50"
 async-tungstenite = "0.16"

crates/server/src/rpc.rs 🔗

@@ -6,6 +6,7 @@ use super::{
     AppState,
 };
 use anyhow::anyhow;
+use async_io::Timer;
 use async_std::task;
 use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
 use collections::{HashMap, HashSet};
@@ -16,7 +17,12 @@ use rpc::{
     Connection, ConnectionId, Peer, TypedEnvelope,
 };
 use sha1::{Digest as _, Sha1};
-use std::{any::TypeId, future::Future, sync::Arc, time::Instant};
+use std::{
+    any::TypeId,
+    future::Future,
+    sync::Arc,
+    time::{Duration, Instant},
+};
 use store::{Store, Worktree};
 use surf::StatusCode;
 use tide::log;
@@ -40,10 +46,13 @@ pub struct Server {
     notifications: Option<mpsc::UnboundedSender<()>>,
 }
 
-pub trait Executor {
+pub trait Executor: Send + Clone {
+    type Timer: Send + Future;
     fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
+    fn timer(&self, duration: Duration) -> Self::Timer;
 }
 
+#[derive(Clone)]
 pub struct RealExecutor;
 
 const MESSAGE_COUNT_PER_PAGE: usize = 100;
@@ -167,8 +176,18 @@ impl Server {
     ) -> impl Future<Output = ()> {
         let mut this = self.clone();
         async move {
-            let (connection_id, handle_io, mut incoming_rx) =
-                this.peer.add_connection(connection).await;
+            let (connection_id, handle_io, mut incoming_rx) = this
+                .peer
+                .add_connection(connection, {
+                    let executor = executor.clone();
+                    move |duration| {
+                        let timer = executor.timer(duration);
+                        async move {
+                            timer.await;
+                        }
+                    }
+                })
+                .await;
 
             if let Some(send_connection_id) = send_connection_id.as_mut() {
                 let _ = send_connection_id.send(connection_id).await;
@@ -883,9 +902,15 @@ impl Server {
 }
 
 impl Executor for RealExecutor {
+    type Timer = Timer;
+
     fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
         task::spawn(future);
     }
+
+    fn timer(&self, duration: Duration) -> Self::Timer {
+        Timer::after(duration)
+    }
 }
 
 fn broadcast<F>(
@@ -1005,7 +1030,7 @@ mod tests {
     };
     use lsp;
     use parking_lot::Mutex;
-    use postage::{sink::Sink, watch};
+    use postage::{barrier, watch};
     use project::{
         fs::{FakeFs, Fs as _},
         search::SearchQuery,
@@ -1759,7 +1784,7 @@ mod tests {
     }
 
     #[gpui::test(iterations = 10)]
-    async fn test_peer_disconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+    async fn test_leaving_project(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
         cx_a.foreground().forbid_parking();
         let lang_registry = Arc::new(LanguageRegistry::new());
         let fs = FakeFs::new(cx_a.background());
@@ -1817,16 +1842,40 @@ mod tests {
         .await
         .unwrap();
 
-        // See that a guest has joined as client A.
+        // Client A sees that a guest has joined.
         project_a
             .condition(&cx_a, |p, _| p.collaborators().len() == 1)
             .await;
 
-        // Drop client B's connection and ensure client A observes client B leaving the worktree.
+        // Drop client B's connection and ensure client A observes client B leaving the project.
         client_b.disconnect(&cx_b.to_async()).unwrap();
         project_a
             .condition(&cx_a, |p, _| p.collaborators().len() == 0)
             .await;
+
+        // Rejoin the project as client B
+        let _project_b = Project::remote(
+            project_id,
+            client_b.clone(),
+            client_b.user_store.clone(),
+            lang_registry.clone(),
+            fs.clone(),
+            &mut cx_b.to_async(),
+        )
+        .await
+        .unwrap();
+
+        // Client A sees that a guest has re-joined.
+        project_a
+            .condition(&cx_a, |p, _| p.collaborators().len() == 1)
+            .await;
+
+        // Simulate connection loss for client B and ensure client A observes client B leaving the project.
+        server.disconnect_client(client_b.current_user_id(cx_b));
+        cx_a.foreground().advance_clock(Duration::from_secs(3));
+        project_a
+            .condition(&cx_a, |p, _| p.collaborators().len() == 0)
+            .await;
     }
 
     #[gpui::test(iterations = 10)]
@@ -2683,8 +2732,6 @@ mod tests {
             .read_with(cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
             .await;
 
-        eprintln!("sharing");
-
         project_a.update(cx_a, |p, cx| p.share(cx)).await.unwrap();
 
         // Join the worktree as client B.
@@ -3850,6 +3897,7 @@ mod tests {
         // Disconnect client B, ensuring we can still access its cached channel data.
         server.forbid_connections();
         server.disconnect_client(client_b.current_user_id(&cx_b));
+        cx_b.foreground().advance_clock(Duration::from_secs(3));
         while !matches!(
             status_b.next().await,
             Some(client::Status::ReconnectionError { .. })
@@ -4340,7 +4388,7 @@ mod tests {
         server: Arc<Server>,
         foreground: Rc<executor::Foreground>,
         notifications: mpsc::UnboundedReceiver<()>,
-        connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+        connection_killers: Arc<Mutex<HashMap<UserId, barrier::Sender>>>,
         forbid_connections: Arc<AtomicBool>,
         _test_db: TestDb,
     }
@@ -4444,9 +4492,7 @@ mod tests {
         }
 
         fn disconnect_client(&self, user_id: UserId) {
-            if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
-                let _ = kill_conn.try_send(Some(()));
-            }
+            self.connection_killers.lock().remove(&user_id);
         }
 
         fn forbid_connections(&self) {
@@ -5031,9 +5077,15 @@ mod tests {
     }
 
     impl Executor for Arc<gpui::executor::Background> {
+        type Timer = gpui::executor::Timer;
+
         fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
             self.spawn(future).detach();
         }
+
+        fn timer(&self, duration: Duration) -> Self::Timer {
+            self.as_ref().timer(duration)
+        }
     }
 
     fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {

crates/text/src/network.rs 🔗

@@ -0,0 +1,69 @@
+use clock::ReplicaId;
+
+pub struct Network<T: Clone, R: rand::Rng> {
+    inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
+    all_messages: Vec<T>,
+    rng: R,
+}
+
+#[derive(Clone)]
+struct Envelope<T: Clone> {
+    message: T,
+}
+
+impl<T: Clone, R: rand::Rng> Network<T, R> {
+    pub fn new(rng: R) -> Self {
+        Network {
+            inboxes: Default::default(),
+            all_messages: Vec::new(),
+            rng,
+        }
+    }
+
+    pub fn add_peer(&mut self, id: ReplicaId) {
+        self.inboxes.insert(id, Vec::new());
+    }
+
+    pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) {
+        self.inboxes
+            .insert(new_replica_id, self.inboxes[&old_replica_id].clone());
+    }
+
+    pub fn is_idle(&self) -> bool {
+        self.inboxes.values().all(|i| i.is_empty())
+    }
+
+    pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
+        for (replica, inbox) in self.inboxes.iter_mut() {
+            if *replica != sender {
+                for message in &messages {
+                    // Insert one or more duplicates of this message, potentially *before* the previous
+                    // message sent by this peer to simulate out-of-order delivery.
+                    for _ in 0..self.rng.gen_range(1..4) {
+                        let insertion_index = self.rng.gen_range(0..inbox.len() + 1);
+                        inbox.insert(
+                            insertion_index,
+                            Envelope {
+                                message: message.clone(),
+                            },
+                        );
+                    }
+                }
+            }
+        }
+        self.all_messages.extend(messages);
+    }
+
+    pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
+        !self.inboxes[&receiver].is_empty()
+    }
+
+    pub fn receive(&mut self, receiver: ReplicaId) -> Vec<T> {
+        let inbox = self.inboxes.get_mut(&receiver).unwrap();
+        let count = self.rng.gen_range(0..inbox.len() + 1);
+        inbox
+            .drain(0..count)
+            .map(|envelope| envelope.message)
+            .collect()
+    }
+}

crates/text/src/tests.rs 🔗

@@ -1,4 +1,4 @@
-use super::*;
+use super::{network::Network, *};
 use clock::ReplicaId;
 use rand::prelude::*;
 use std::{
@@ -7,7 +7,6 @@ use std::{
     iter::Iterator,
     time::{Duration, Instant},
 };
-use util::test::Network;
 
 #[cfg(test)]
 #[ctor::ctor]

crates/text/src/text.rs 🔗

@@ -1,5 +1,7 @@
 mod anchor;
 pub mod locator;
+#[cfg(any(test, feature = "test-support"))]
+pub mod network;
 pub mod operation_queue;
 mod patch;
 mod point;

crates/util/Cargo.toml 🔗

@@ -7,10 +7,9 @@ edition = "2021"
 doctest = false
 
 [features]
-test-support = ["clock", "rand", "serde_json", "tempdir"]
+test-support = ["rand", "serde_json", "tempdir"]
 
 [dependencies]
-clock = { path = "../clock", optional = true }
 anyhow = "1.0.38"
 futures = "0.3"
 log = "0.4"

crates/util/src/lib.rs 🔗

@@ -123,6 +123,18 @@ where
     }
 }
 
+struct Defer<F: FnOnce()>(Option<F>);
+
+impl<F: FnOnce()> Drop for Defer<F> {
+    fn drop(&mut self) {
+        self.0.take().map(|f| f());
+    }
+}
+
+pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
+    Defer(Some(f))
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

crates/util/src/test.rs 🔗

@@ -1,75 +1,6 @@
-use clock::ReplicaId;
 use std::path::{Path, PathBuf};
 use tempdir::TempDir;
 
-#[derive(Clone)]
-struct Envelope<T: Clone> {
-    message: T,
-}
-
-pub struct Network<T: Clone, R: rand::Rng> {
-    inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
-    all_messages: Vec<T>,
-    rng: R,
-}
-
-impl<T: Clone, R: rand::Rng> Network<T, R> {
-    pub fn new(rng: R) -> Self {
-        Network {
-            inboxes: Default::default(),
-            all_messages: Vec::new(),
-            rng,
-        }
-    }
-
-    pub fn add_peer(&mut self, id: ReplicaId) {
-        self.inboxes.insert(id, Vec::new());
-    }
-
-    pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) {
-        self.inboxes
-            .insert(new_replica_id, self.inboxes[&old_replica_id].clone());
-    }
-
-    pub fn is_idle(&self) -> bool {
-        self.inboxes.values().all(|i| i.is_empty())
-    }
-
-    pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
-        for (replica, inbox) in self.inboxes.iter_mut() {
-            if *replica != sender {
-                for message in &messages {
-                    // Insert one or more duplicates of this message, potentially *before* the previous
-                    // message sent by this peer to simulate out-of-order delivery.
-                    for _ in 0..self.rng.gen_range(1..4) {
-                        let insertion_index = self.rng.gen_range(0..inbox.len() + 1);
-                        inbox.insert(
-                            insertion_index,
-                            Envelope {
-                                message: message.clone(),
-                            },
-                        );
-                    }
-                }
-            }
-        }
-        self.all_messages.extend(messages);
-    }
-
-    pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
-        !self.inboxes[&receiver].is_empty()
-    }
-
-    pub fn receive(&mut self, receiver: ReplicaId) -> Vec<T> {
-        let inbox = self.inboxes.get_mut(&receiver).unwrap();
-        let count = self.rng.gen_range(0..inbox.len() + 1);
-        inbox
-            .drain(0..count)
-            .map(|envelope| envelope.message)
-            .collect()
-    }
-}
-
 pub fn temp_tree(tree: serde_json::Value) -> TempDir {
     let dir = TempDir::new("").unwrap();
     write_tree(dir.path(), tree);