Send heartbeats in both directions so the server can detect when clients disconnect
Author: Max Brunsfeld
.zed.toml | 2
Cargo.lock | 6
crates/client/src/client.rs | 59 ++-------
crates/client/src/test.rs | 9 +
crates/clock/Cargo.toml | 1
crates/clock/src/clock.rs | 31 -----
crates/gpui/Cargo.toml | 1
crates/gpui/src/executor.rs | 147 +++++++++++++++++---------
crates/gpui/src/util.rs | 1
crates/language/src/proto.rs | 45 +++++--
crates/language/src/tests.rs | 3
crates/project/src/lsp_command.rs | 26 ++--
crates/project/src/project.rs | 26 ++--
crates/project/src/worktree.rs | 15 +-
crates/rpc/Cargo.toml | 2
crates/rpc/src/conn.rs | 124 ++++++++++-----------
crates/rpc/src/peer.rs | 182 ++++++++++++++++++++++----------
crates/rpc/src/proto.rs | 51 ++++++--
crates/server/Cargo.toml | 1
crates/server/src/rpc.rs | 80 +++++++++++--
crates/text/src/network.rs | 69 ++++++++++++
crates/text/src/tests.rs | 3
crates/text/src/text.rs | 2
crates/util/Cargo.toml | 3
crates/util/src/lib.rs | 12 ++
crates/util/src/test.rs | 69 ------------
26 files changed, 571 insertions(+), 399 deletions(-)
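To make the commit message concrete before diving into the hunks, here is a hedged, standalone sketch of the keepalive loop that Peer::add_connection now runs (see crates/rpc/src/peer.rs below). The channel transport, the printed ping, and the name handle_io_sketch are illustrative stand-ins rather than the crate's API; the real loop also multiplexes outgoing envelopes and applies WRITE_TIMEOUT to writes. The constants mirror the ones introduced in the diff: ping every KEEPALIVE_INTERVAL, give up if nothing arrives for three intervals. Assumes the anyhow, futures, and smol crates already used across the workspace.

use anyhow::anyhow;
use futures::{channel::mpsc, FutureExt, StreamExt};
use std::time::Duration;

const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);

async fn handle_io_sketch(mut incoming: mpsc::UnboundedReceiver<String>) -> anyhow::Result<()> {
    // Send pings at this frequency so the other side knows we're still alive.
    let keepalive_timer = smol::Timer::after(KEEPALIVE_INTERVAL).fuse();
    futures::pin_mut!(keepalive_timer);

    loop {
        let read_message = incoming.next().fuse();
        futures::pin_mut!(read_message);
        // Disconnect if nothing is received for three keepalive intervals.
        let receive_timeout = smol::Timer::after(KEEPALIVE_INTERVAL * 3).fuse();
        futures::pin_mut!(receive_timeout);

        loop {
            futures::select_biased! {
                message = read_message => match message {
                    // Any inbound traffic resets the receive timeout (the break recreates it).
                    Some(message) => {
                        println!("received {}", message);
                        break;
                    }
                    None => return Ok(()), // peer closed the connection cleanly
                },
                _ = keepalive_timer => {
                    println!("sending ping");
                    keepalive_timer.set(smol::Timer::after(KEEPALIVE_INTERVAL).fuse());
                }
                _ = receive_timeout => return Err(anyhow!("delay between messages too long")),
            }
        }
    }
}

fn main() -> anyhow::Result<()> {
    smol::block_on(async {
        let (tx, rx) = mpsc::unbounded();
        smol::spawn(async move {
            for message in ["hello", "world"] {
                smol::Timer::after(Duration::from_millis(1500)).await;
                tx.unbounded_send(message.to_string()).ok();
            }
            // Dropping `tx` ends the stream, so handle_io_sketch returns Ok.
        })
        .detach();
        handle_io_sketch(rx).await
    })
}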
@@ -1 +1 @@
-collaborators = ["nathansobo", "as-cii", "maxbrunsfeld", "iamnbutler"]
+collaborators = ["nathansobo", "as-cii", "maxbrunsfeld", "iamnbutler", "Kethku"]
@@ -998,7 +998,6 @@ dependencies = [
name = "clock"
version = "0.1.0"
dependencies = [
- "rpc",
"smallvec",
]
@@ -2236,6 +2235,7 @@ dependencies = [
"tiny-skia",
"tree-sitter",
"usvg",
+ "util",
"waker-fn",
]
@@ -3959,6 +3959,7 @@ dependencies = [
"async-lock",
"async-tungstenite",
"base64 0.13.0",
+ "clock",
"futures",
"gpui",
"log",
@@ -3972,6 +3973,7 @@ dependencies = [
"smol",
"smol-timeout",
"tempdir",
+ "util",
"zstd",
]
@@ -5574,7 +5576,6 @@ name = "util"
version = "0.1.0"
dependencies = [
"anyhow",
- "clock",
"futures",
"log",
"rand 0.8.3",
@@ -5959,6 +5960,7 @@ name = "zed-server"
version = "0.1.0"
dependencies = [
"anyhow",
+ "async-io",
"async-sqlx-session",
"async-std",
"async-trait",
@@ -137,8 +137,8 @@ struct ClientState {
credentials: Option<Credentials>,
status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
- _maintain_connection: Option<Task<()>>,
- heartbeat_interval: Duration,
+ _reconnect_task: Option<Task<()>>,
+ reconnect_interval: Duration,
models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
models_by_message_type: HashMap<TypeId, AnyWeakModelHandle>,
model_types_by_message_type: HashMap<TypeId, TypeId>,
@@ -168,8 +168,8 @@ impl Default for ClientState {
credentials: None,
status: watch::channel_with(Status::SignedOut),
entity_id_extractors: Default::default(),
- _maintain_connection: None,
- heartbeat_interval: Duration::from_secs(5),
+ _reconnect_task: None,
+ reconnect_interval: Duration::from_secs(5),
models_by_message_type: Default::default(),
models_by_entity_type_and_remote_id: Default::default(),
model_types_by_message_type: Default::default(),
@@ -236,7 +236,7 @@ impl Client {
#[cfg(any(test, feature = "test-support"))]
pub fn tear_down(&self) {
let mut state = self.state.write();
- state._maintain_connection.take();
+ state._reconnect_task.take();
state.message_handlers.clear();
state.models_by_message_type.clear();
state.models_by_entity_type_and_remote_id.clear();
@@ -283,21 +283,12 @@ impl Client {
match status {
Status::Connected { .. } => {
- let heartbeat_interval = state.heartbeat_interval;
- let this = self.clone();
- let foreground = cx.foreground();
- state._maintain_connection = Some(cx.foreground().spawn(async move {
- loop {
- foreground.timer(heartbeat_interval).await;
- let _ = this.request(proto::Ping {}).await;
- }
- }));
+ state._reconnect_task = None;
}
Status::ConnectionLost => {
let this = self.clone();
- let foreground = cx.foreground();
- let heartbeat_interval = state.heartbeat_interval;
- state._maintain_connection = Some(cx.spawn(|cx| async move {
+ let reconnect_interval = state.reconnect_interval;
+ state._reconnect_task = Some(cx.spawn(|cx| async move {
let mut rng = StdRng::from_entropy();
let mut delay = Duration::from_millis(100);
while let Err(error) = this.authenticate_and_connect(&cx).await {
@@ -308,15 +299,15 @@ impl Client {
},
&cx,
);
- foreground.timer(delay).await;
+ cx.background().timer(delay).await;
delay = delay
.mul_f32(rng.gen_range(1.0..=2.0))
- .min(heartbeat_interval);
+ .min(reconnect_interval);
}
}));
}
Status::SignedOut | Status::UpgradeRequired => {
- state._maintain_connection.take();
+ state._reconnect_task.take();
}
_ => {}
}
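On the ConnectionLost arm above: the retry delay starts at 100 ms and grows by a random factor between 1 and 2 on each failed reconnect, capped at reconnect_interval (5 seconds by default in ClientState). A hypothetical standalone illustration of just that growth, using the same rand 0.8 calls:

use rand::{rngs::StdRng, Rng, SeedableRng};
use std::time::Duration;

fn main() {
    let reconnect_interval = Duration::from_secs(5);
    let mut rng = StdRng::from_entropy();
    let mut delay = Duration::from_millis(100);
    for attempt in 1..=10 {
        println!("attempt {} failed, retrying in {:?}", attempt, delay);
        // Randomized exponential backoff, capped at the reconnect interval.
        delay = delay
            .mul_f32(rng.gen_range(1.0..=2.0))
            .min(reconnect_interval);
    }
}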
@@ -548,7 +539,11 @@ impl Client {
}
async fn set_connection(self: &Arc<Self>, conn: Connection, cx: &AsyncAppContext) {
- let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
+ let executor = cx.background();
+ let (connection_id, handle_io, mut incoming) = self
+ .peer
+ .add_connection(conn, move |duration| executor.timer(duration))
+ .await;
cx.foreground()
.spawn({
let cx = cx.clone();
@@ -940,26 +935,6 @@ mod tests {
use crate::test::{FakeHttpClient, FakeServer};
use gpui::TestAppContext;
- #[gpui::test(iterations = 10)]
- async fn test_heartbeat(cx: &mut TestAppContext) {
- cx.foreground().forbid_parking();
-
- let user_id = 5;
- let mut client = Client::new(FakeHttpClient::with_404_response());
- let server = FakeServer::for_client(user_id, &mut client, &cx).await;
-
- cx.foreground().advance_clock(Duration::from_secs(10));
- let ping = server.receive::<proto::Ping>().await.unwrap();
- server.respond(ping.receipt(), proto::Ack {}).await;
-
- cx.foreground().advance_clock(Duration::from_secs(10));
- let ping = server.receive::<proto::Ping>().await.unwrap();
- server.respond(ping.receipt(), proto::Ack {}).await;
-
- client.disconnect(&cx.to_async()).unwrap();
- assert!(server.receive::<proto::Ping>().await.is_err());
- }
-
#[gpui::test(iterations = 10)]
async fn test_reconnection(cx: &mut TestAppContext) {
cx.foreground().forbid_parking();
@@ -991,8 +966,6 @@ mod tests {
server.roll_access_token();
server.allow_connections();
cx.foreground().advance_clock(Duration::from_secs(10));
- assert_eq!(server.auth_count(), 1);
- cx.foreground().advance_clock(Duration::from_secs(10));
while !matches!(status.next().await, Some(Status::Connected { .. })) {}
assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token
}
@@ -6,6 +6,7 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
use gpui::{executor, ModelHandle, TestAppContext};
use parking_lot::Mutex;
+use postage::barrier;
use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
use std::{fmt, rc::Rc, sync::Arc};
@@ -22,6 +23,7 @@ struct FakeServerState {
connection_id: Option<ConnectionId>,
forbid_connections: bool,
auth_count: usize,
+ connection_killer: Option<barrier::Sender>,
access_token: usize,
}
@@ -74,12 +76,15 @@ impl FakeServer {
Err(EstablishConnectionError::Unauthorized)?
}
- let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
- let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+ let (client_conn, server_conn, kill) =
+ Connection::in_memory(cx.background());
+ let (connection_id, io, incoming) =
+ peer.add_test_connection(server_conn, cx.background()).await;
cx.background().spawn(io).detach();
let mut state = state.lock();
state.connection_id = Some(connection_id);
state.incoming = Some(incoming);
+ state.connection_killer = Some(kill);
Ok(client_conn)
})
}
@@ -9,4 +9,3 @@ doctest = false
[dependencies]
smallvec = { version = "1.6", features = ["union"] }
-rpc = { path = "../rpc" }
@@ -69,37 +69,6 @@ impl<'a> AddAssign<&'a Local> for Local {
#[derive(Clone, Default, Hash, Eq, PartialEq)]
pub struct Global(SmallVec<[u32; 8]>);
-impl From<Vec<rpc::proto::VectorClockEntry>> for Global {
- fn from(message: Vec<rpc::proto::VectorClockEntry>) -> Self {
- let mut version = Self::new();
- for entry in message {
- version.observe(Local {
- replica_id: entry.replica_id as ReplicaId,
- value: entry.timestamp,
- });
- }
- version
- }
-}
-
-impl<'a> From<&'a Global> for Vec<rpc::proto::VectorClockEntry> {
- fn from(version: &'a Global) -> Self {
- version
- .iter()
- .map(|entry| rpc::proto::VectorClockEntry {
- replica_id: entry.replica_id as u32,
- timestamp: entry.value,
- })
- .collect()
- }
-}
-
-impl From<Global> for Vec<rpc::proto::VectorClockEntry> {
- fn from(version: Global) -> Self {
- (&version).into()
- }
-}
-
impl Global {
pub fn new() -> Self {
Self::default()
@@ -14,6 +14,7 @@ test-support = ["backtrace", "dhat", "env_logger", "collections/test-support"]
[dependencies]
collections = { path = "../collections" }
gpui_macros = { path = "../gpui_macros" }
+util = { path = "../util" }
sum_tree = { path = "../sum_tree" }
async-task = "4.0.3"
backtrace = { version = "0.3", optional = true }
@@ -1,6 +1,6 @@
use anyhow::{anyhow, Result};
use async_task::Runnable;
-use smol::{channel, prelude::*, Executor, Timer};
+use smol::{channel, prelude::*, Executor};
use std::{
any::Any,
fmt::{self, Display},
@@ -86,6 +86,19 @@ pub struct Deterministic {
parker: parking_lot::Mutex<parking::Parker>,
}
+pub enum Timer {
+ Production(smol::Timer),
+ #[cfg(any(test, feature = "test-support"))]
+ Deterministic(DeterministicTimer),
+}
+
+#[cfg(any(test, feature = "test-support"))]
+pub struct DeterministicTimer {
+ rx: postage::barrier::Receiver,
+ id: usize,
+ state: Arc<parking_lot::Mutex<DeterministicState>>,
+}
+
#[cfg(any(test, feature = "test-support"))]
impl Deterministic {
pub fn new(seed: u64) -> Arc<Self> {
@@ -306,15 +319,82 @@ impl Deterministic {
None
}
- pub fn advance_clock(&self, duration: Duration) {
+ pub fn timer(&self, duration: Duration) -> Timer {
+ let (tx, rx) = postage::barrier::channel();
let mut state = self.state.lock();
- state.now += duration;
- let now = state.now;
- let mut pending_timers = mem::take(&mut state.pending_timers);
- drop(state);
+ let wakeup_at = state.now + duration;
+ let id = util::post_inc(&mut state.next_timer_id);
+ state.pending_timers.push((id, wakeup_at, tx));
+ let state = self.state.clone();
+ Timer::Deterministic(DeterministicTimer { rx, id, state })
+ }
+
+ pub fn advance_clock(&self, duration: Duration) {
+ let new_now = self.state.lock().now + duration;
+ loop {
+ self.run_until_parked();
+ let mut state = self.state.lock();
- pending_timers.retain(|(_, wakeup, _)| *wakeup > now);
- self.state.lock().pending_timers.extend(pending_timers);
+ if let Some((_, wakeup_time, _)) = state.pending_timers.first() {
+ let wakeup_time = *wakeup_time;
+ if wakeup_time < new_now {
+ let timer_count = state
+ .pending_timers
+ .iter()
+ .take_while(|(_, t, _)| *t == wakeup_time)
+ .count();
+ state.now = wakeup_time;
+ let timers_to_wake = state
+ .pending_timers
+ .drain(0..timer_count)
+ .collect::<Vec<_>>();
+ drop(state);
+ drop(timers_to_wake);
+ continue;
+ }
+ }
+
+ break;
+ }
+
+ self.state.lock().now = new_now;
+ }
+}
+
+impl Drop for Timer {
+ fn drop(&mut self) {
+ #[cfg(any(test, feature = "test-support"))]
+ if let Timer::Deterministic(DeterministicTimer { state, id, .. }) = self {
+ state
+ .lock()
+ .pending_timers
+ .retain(|(timer_id, _, _)| timer_id != id)
+ }
+ }
+}
+
+impl Future for Timer {
+ type Output = ();
+
+ fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+ match &mut *self {
+ #[cfg(any(test, feature = "test-support"))]
+ Self::Deterministic(DeterministicTimer { rx, .. }) => {
+ use postage::stream::{PollRecv, Stream as _};
+ smol::pin!(rx);
+ match rx.poll_recv(&mut postage::Context::from_waker(cx.waker())) {
+ PollRecv::Ready(()) | PollRecv::Closed => Poll::Ready(()),
+ PollRecv::Pending => Poll::Pending,
+ }
+ }
+ Self::Production(timer) => {
+ smol::pin!(timer);
+ match timer.poll(cx) {
+ Poll::Ready(_) => Poll::Ready(()),
+ Poll::Pending => Poll::Pending,
+ }
+ }
+ }
}
}
@@ -438,46 +518,6 @@ impl Foreground {
}
}
- pub async fn timer(&self, duration: Duration) {
- match self {
- #[cfg(any(test, feature = "test-support"))]
- Self::Deterministic { executor, .. } => {
- use postage::prelude::Stream as _;
-
- let (tx, mut rx) = postage::barrier::channel();
- let timer_id;
- {
- let mut state = executor.state.lock();
- let wakeup_at = state.now + duration;
- timer_id = util::post_inc(&mut state.next_timer_id);
- state.pending_timers.push((timer_id, wakeup_at, tx));
- }
-
- struct DropTimer<'a>(usize, &'a Foreground);
- impl<'a> Drop for DropTimer<'a> {
- fn drop(&mut self) {
- match self.1 {
- Foreground::Deterministic { executor, .. } => {
- executor
- .state
- .lock()
- .pending_timers
- .retain(|(timer_id, _, _)| *timer_id != self.0);
- }
- _ => unreachable!(),
- }
- }
- }
-
- let _guard = DropTimer(timer_id, self);
- rx.recv().await;
- }
- _ => {
- Timer::after(duration).await;
- }
- }
- }
-
#[cfg(any(test, feature = "test-support"))]
pub fn advance_clock(&self, duration: Duration) {
match self {
@@ -600,6 +640,14 @@ impl Background {
}
}
+ pub fn timer(&self, duration: Duration) -> Timer {
+ match self {
+ Background::Production { .. } => Timer::Production(smol::Timer::after(duration)),
+ #[cfg(any(test, feature = "test-support"))]
+ Background::Deterministic { executor } => executor.timer(duration),
+ }
+ }
+
#[cfg(any(test, feature = "test-support"))]
pub async fn simulate_random_delay(&self) {
use rand::prelude::*;
@@ -612,9 +660,6 @@ impl Background {
for _ in 0..yields {
yield_now().await;
}
-
- let delay = Duration::from_millis(executor.state.lock().rng.gen_range(0..100));
- executor.advance_clock(delay);
}
}
_ => panic!("this method can only be called on a deterministic executor"),
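A hedged sketch of how the new executor::Timer behaves under gpui's deterministic test executor: it only completes once the test advances the fake clock, which is what the server tests later in this diff rely on when they call advance_clock(Duration::from_secs(3)). The test name is hypothetical, it assumes a crate with gpui's test-support enabled, and harness details may differ slightly.

use gpui::TestAppContext;
use std::time::Duration;

#[gpui::test]
async fn test_deterministic_timer(cx: &mut TestAppContext) {
    cx.foreground().forbid_parking();

    // Under the deterministic executor, this timer never fires in real time.
    let timer = cx.background().timer(Duration::from_secs(2));

    // Advancing the fake clock past the deadline fires it, so awaiting returns immediately.
    cx.foreground().advance_clock(Duration::from_secs(3));
    timer.await;
}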
@@ -1,5 +1,6 @@
use smol::future::FutureExt;
use std::{future::Future, time::Duration};
+pub use util::*;
pub fn post_inc(value: &mut usize) -> usize {
let prev = *value;
@@ -25,13 +25,13 @@ pub fn serialize_operation(operation: &Operation) -> proto::Operation {
replica_id: undo.id.replica_id as u32,
local_timestamp: undo.id.value,
lamport_timestamp: lamport_timestamp.value,
- version: From::from(&undo.version),
+ version: serialize_version(&undo.version),
transaction_ranges: undo
.transaction_ranges
.iter()
.map(serialize_range)
.collect(),
- transaction_version: From::from(&undo.transaction_version),
+ transaction_version: serialize_version(&undo.transaction_version),
counts: undo
.counts
.iter()
@@ -77,7 +77,7 @@ pub fn serialize_edit_operation(operation: &EditOperation) -> proto::operation::
replica_id: operation.timestamp.replica_id as u32,
local_timestamp: operation.timestamp.local,
lamport_timestamp: operation.timestamp.lamport,
- version: From::from(&operation.version),
+ version: serialize_version(&operation.version),
ranges: operation.ranges.iter().map(serialize_range).collect(),
new_text: operation.new_text.clone(),
}
@@ -116,7 +116,7 @@ pub fn serialize_buffer_fragment(fragment: &text::Fragment) -> proto::BufferFrag
timestamp: clock.value,
})
.collect(),
- max_undos: From::from(&fragment.max_undos),
+ max_undos: serialize_version(&fragment.max_undos),
}
}
@@ -188,7 +188,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<Operation> {
replica_id: undo.replica_id as ReplicaId,
value: undo.local_timestamp,
},
- version: undo.version.into(),
+ version: deserialize_version(undo.version),
counts: undo
.counts
.into_iter()
@@ -207,7 +207,7 @@ pub fn deserialize_operation(message: proto::Operation) -> Result<Operation> {
.into_iter()
.map(deserialize_range)
.collect(),
- transaction_version: undo.transaction_version.into(),
+ transaction_version: deserialize_version(undo.transaction_version),
},
}),
proto::operation::Variant::UpdateSelections(message) => {
@@ -260,7 +260,7 @@ pub fn deserialize_edit_operation(edit: proto::operation::Edit) -> EditOperation
local: edit.local_timestamp,
lamport: edit.lamport_timestamp,
},
- version: edit.version.into(),
+ version: deserialize_version(edit.version),
ranges: edit.ranges.into_iter().map(deserialize_range).collect(),
new_text: edit.new_text,
}
@@ -309,7 +309,7 @@ pub fn deserialize_buffer_fragment(
replica_id: entry.replica_id as ReplicaId,
value: entry.timestamp,
})),
- max_undos: From::from(message.max_undos),
+ max_undos: deserialize_version(message.max_undos),
}
}
@@ -472,8 +472,8 @@ pub fn serialize_transaction(transaction: &Transaction) -> proto::Transaction {
.copied()
.map(serialize_local_timestamp)
.collect(),
- start: (&transaction.start).into(),
- end: (&transaction.end).into(),
+ start: serialize_version(&transaction.start),
+ end: serialize_version(&transaction.end),
ranges: transaction.ranges.iter().map(serialize_range).collect(),
}
}
@@ -490,8 +490,8 @@ pub fn deserialize_transaction(transaction: proto::Transaction) -> Result<Transa
.into_iter()
.map(deserialize_local_timestamp)
.collect(),
- start: transaction.start.into(),
- end: transaction.end.into(),
+ start: deserialize_version(transaction.start.into()),
+ end: deserialize_version(transaction.end),
ranges: transaction
.ranges
.into_iter()
@@ -524,3 +524,24 @@ pub fn serialize_range(range: &Range<FullOffset>) -> proto::Range {
pub fn deserialize_range(range: proto::Range) -> Range<FullOffset> {
FullOffset(range.start as usize)..FullOffset(range.end as usize)
}
+
+pub fn deserialize_version(message: Vec<proto::VectorClockEntry>) -> clock::Global {
+ let mut version = clock::Global::new();
+ for entry in message {
+ version.observe(clock::Local {
+ replica_id: entry.replica_id as ReplicaId,
+ value: entry.timestamp,
+ });
+ }
+ version
+}
+
+pub fn serialize_version(version: &clock::Global) -> Vec<proto::VectorClockEntry> {
+ version
+ .iter()
+ .map(|entry| proto::VectorClockEntry {
+ replica_id: entry.replica_id as u32,
+ timestamp: entry.value,
+ })
+ .collect()
+}
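A hypothetical round-trip check for the two helpers above; the replica IDs and timestamps are arbitrary, and plain assert! is used because clock::Global derives PartialEq but not Debug.

#[test]
fn vector_clock_round_trips_through_proto() {
    let mut version = clock::Global::new();
    version.observe(clock::Local { replica_id: 1, value: 3 });
    version.observe(clock::Local { replica_id: 2, value: 7 });

    // serialize_version and deserialize_version are inverses for any version vector.
    assert!(deserialize_version(serialize_version(&version)) == version);
}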
@@ -11,8 +11,9 @@ use std::{
rc::Rc,
time::{Duration, Instant},
};
+use text::network::Network;
use unindent::Unindent as _;
-use util::{post_inc, test::Network};
+use util::post_inc;
#[cfg(test)]
#[ctor::ctor]
@@ -5,7 +5,7 @@ use client::{proto, PeerId};
use gpui::{AppContext, AsyncAppContext, ModelHandle};
use language::{
point_from_lsp,
- proto::{deserialize_anchor, serialize_anchor},
+ proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version},
range_from_lsp, Anchor, Bias, Buffer, PointUtf16, ToLspPosition, ToPointUtf16,
};
use lsp::{DocumentHighlightKind, ServerCapabilities};
@@ -126,7 +126,7 @@ impl LspCommand for PrepareRename {
position: Some(language::proto::serialize_anchor(
&buffer.anchor_before(self.position),
)),
- version: (&buffer.version()).into(),
+ version: serialize_version(&buffer.version()),
}
}
@@ -142,7 +142,7 @@ impl LspCommand for PrepareRename {
.ok_or_else(|| anyhow!("invalid position"))?;
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(message.version.into())
+ buffer.wait_for_version(deserialize_version(message.version))
})
.await;
@@ -166,7 +166,7 @@ impl LspCommand for PrepareRename {
end: range
.as_ref()
.map(|range| language::proto::serialize_anchor(&range.end)),
- version: buffer_version.into(),
+ version: serialize_version(buffer_version),
}
}
@@ -180,7 +180,7 @@ impl LspCommand for PrepareRename {
if message.can_rename {
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(message.version.into())
+ buffer.wait_for_version(deserialize_version(message.version))
})
.await;
let start = message.start.and_then(deserialize_anchor);
@@ -255,7 +255,7 @@ impl LspCommand for PerformRename {
&buffer.anchor_before(self.position),
)),
new_name: self.new_name.clone(),
- version: (&buffer.version()).into(),
+ version: serialize_version(&buffer.version()),
}
}
@@ -271,7 +271,7 @@ impl LspCommand for PerformRename {
.ok_or_else(|| anyhow!("invalid position"))?;
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(message.version.into())
+ buffer.wait_for_version(deserialize_version(message.version))
})
.await;
Ok(Self {
@@ -407,7 +407,7 @@ impl LspCommand for GetDefinition {
position: Some(language::proto::serialize_anchor(
&buffer.anchor_before(self.position),
)),
- version: (&buffer.version()).into(),
+ version: serialize_version(&buffer.version()),
}
}
@@ -423,7 +423,7 @@ impl LspCommand for GetDefinition {
.ok_or_else(|| anyhow!("invalid position"))?;
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(message.version.into())
+ buffer.wait_for_version(deserialize_version(message.version))
})
.await;
Ok(Self {
@@ -566,7 +566,7 @@ impl LspCommand for GetReferences {
position: Some(language::proto::serialize_anchor(
&buffer.anchor_before(self.position),
)),
- version: (&buffer.version()).into(),
+ version: serialize_version(&buffer.version()),
}
}
@@ -582,7 +582,7 @@ impl LspCommand for GetReferences {
.ok_or_else(|| anyhow!("invalid position"))?;
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(message.version.into())
+ buffer.wait_for_version(deserialize_version(message.version))
})
.await;
Ok(Self {
@@ -706,7 +706,7 @@ impl LspCommand for GetDocumentHighlights {
position: Some(language::proto::serialize_anchor(
&buffer.anchor_before(self.position),
)),
- version: (&buffer.version()).into(),
+ version: serialize_version(&buffer.version()),
}
}
@@ -722,7 +722,7 @@ impl LspCommand for GetDocumentHighlights {
.ok_or_else(|| anyhow!("invalid position"))?;
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(message.version.into())
+ buffer.wait_for_version(deserialize_version(message.version))
})
.await;
Ok(Self {
@@ -15,7 +15,7 @@ use gpui::{
UpgradeModelHandle, WeakModelHandle,
};
use language::{
- proto::{deserialize_anchor, serialize_anchor},
+ proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version},
range_from_lsp, Anchor, AnchorRangeExt, Bias, Buffer, CodeAction, CodeLabel, Completion,
Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, Operation, PointUtf16,
ToLspPosition, ToOffset, ToPointUtf16, Transaction,
@@ -1713,14 +1713,14 @@ impl Project {
project_id,
buffer_id,
position: Some(language::proto::serialize_anchor(&anchor)),
- version: (&source_buffer.version()).into(),
+ version: serialize_version(&source_buffer.version()),
};
cx.spawn_weak(|_, mut cx| async move {
let response = rpc.request(message).await?;
source_buffer_handle
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(response.version.into())
+ buffer.wait_for_version(deserialize_version(response.version))
})
.await;
@@ -1910,13 +1910,13 @@ impl Project {
buffer_id,
start: Some(language::proto::serialize_anchor(&range.start)),
end: Some(language::proto::serialize_anchor(&range.end)),
- version: (&version).into(),
+ version: serialize_version(&version),
})
.await?;
buffer_handle
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(response.version.into())
+ buffer.wait_for_version(deserialize_version(response.version))
})
.await;
@@ -2915,7 +2915,7 @@ impl Project {
mut cx: AsyncAppContext,
) -> Result<proto::BufferSaved> {
let buffer_id = envelope.payload.buffer_id;
- let requested_version = envelope.payload.version.try_into()?;
+ let requested_version = deserialize_version(envelope.payload.version);
let (project_id, buffer) = this.update(&mut cx, |this, cx| {
let project_id = this.remote_id().ok_or_else(|| anyhow!("not connected"))?;
@@ -2936,7 +2936,7 @@ impl Project {
Ok(proto::BufferSaved {
project_id,
buffer_id,
- version: (&saved_version).into(),
+ version: serialize_version(&saved_version),
mtime: Some(mtime.into()),
})
}
@@ -2981,7 +2981,7 @@ impl Project {
.position
.and_then(language::proto::deserialize_anchor)
.ok_or_else(|| anyhow!("invalid position"))?;
- let version = clock::Global::from(envelope.payload.version);
+ let version = deserialize_version(envelope.payload.version);
let buffer = this.read_with(&cx, |this, cx| {
this.opened_buffers
.get(&envelope.payload.buffer_id)
@@ -3001,7 +3001,7 @@ impl Project {
.iter()
.map(language::proto::serialize_completion)
.collect(),
- version: (&version).into(),
+ version: serialize_version(&version),
})
}
@@ -3062,7 +3062,7 @@ impl Project {
})?;
buffer
.update(&mut cx, |buffer, _| {
- buffer.wait_for_version(envelope.payload.version.into())
+ buffer.wait_for_version(deserialize_version(envelope.payload.version))
})
.await;
@@ -3077,7 +3077,7 @@ impl Project {
.iter()
.map(language::proto::serialize_code_action)
.collect(),
- version: (&version).into(),
+ version: serialize_version(&version),
})
}
@@ -3445,7 +3445,7 @@ impl Project {
_: Arc<Client>,
mut cx: AsyncAppContext,
) -> Result<()> {
- let version = envelope.payload.version.try_into()?;
+ let version = deserialize_version(envelope.payload.version);
let mtime = envelope
.payload
.mtime
@@ -3473,7 +3473,7 @@ impl Project {
mut cx: AsyncAppContext,
) -> Result<()> {
let payload = envelope.payload.clone();
- let version = payload.version.try_into()?;
+ let version = deserialize_version(payload.version);
let mtime = payload
.mtime
.ok_or_else(|| anyhow!("missing mtime"))?
@@ -17,7 +17,10 @@ use gpui::{
executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
Task,
};
-use language::{Buffer, DiagnosticEntry, Operation, PointUtf16, Rope};
+use language::{
+ proto::{deserialize_version, serialize_version},
+ Buffer, DiagnosticEntry, Operation, PointUtf16, Rope,
+};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use postage::{
@@ -30,7 +33,7 @@ use smol::channel::{self, Sender};
use std::{
any::Any,
cmp::{self, Ordering},
- convert::{TryFrom, TryInto},
+ convert::TryFrom,
ffi::{OsStr, OsString},
fmt,
future::Future,
@@ -1423,7 +1426,7 @@ impl language::File for File {
rpc.send(proto::BufferSaved {
project_id,
buffer_id,
- version: (&version).into(),
+ version: serialize_version(&version),
mtime: Some(entry.mtime.into()),
})?;
}
@@ -1438,10 +1441,10 @@ impl language::File for File {
.request(proto::SaveBuffer {
project_id,
buffer_id,
- version: (&version).into(),
+ version: serialize_version(&version),
})
.await?;
- let version = response.version.try_into()?;
+ let version = deserialize_version(response.version);
let mtime = response
.mtime
.ok_or_else(|| anyhow!("missing mtime"))?
@@ -1518,7 +1521,7 @@ impl language::LocalFile for File {
.send(proto::BufferReloaded {
project_id,
buffer_id,
- version: version.into(),
+ version: serialize_version(&version),
mtime: Some(mtime.into()),
})
.log_err();
@@ -26,7 +26,9 @@ rsa = "0.4"
serde = { version = "1", features = ["derive"] }
smol-timeout = "0.6"
zstd = "0.9"
+clock = { path = "../clock" }
gpui = { path = "../gpui", optional = true }
+util = { path = "../util" }
[build-dependencies]
prost-build = "0.8"
@@ -1,6 +1,5 @@
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{SinkExt as _, Stream, StreamExt as _};
-use std::{io, task::Poll};
+use futures::{SinkExt as _, StreamExt as _};
pub struct Connection {
pub(crate) tx:
@@ -36,87 +35,82 @@ impl Connection {
#[cfg(any(test, feature = "test-support"))]
pub fn in_memory(
executor: std::sync::Arc<gpui::executor::Background>,
- ) -> (Self, Self, postage::watch::Sender<Option<()>>) {
- let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
- postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
+ ) -> (Self, Self, postage::barrier::Sender) {
+ use postage::prelude::Stream;
- let (a_tx, a_rx) = Self::channel(kill_rx.clone(), executor.clone());
- let (b_tx, b_rx) = Self::channel(kill_rx, executor);
- (
+ let (kill_tx, kill_rx) = postage::barrier::channel();
+ let (a_tx, a_rx) = channel(kill_rx.clone(), executor.clone());
+ let (b_tx, b_rx) = channel(kill_rx, executor);
+ return (
Self { tx: a_tx, rx: b_rx },
Self { tx: b_tx, rx: a_rx },
kill_tx,
- )
- }
+ );
- #[cfg(any(test, feature = "test-support"))]
- fn channel(
- kill_rx: postage::watch::Receiver<Option<()>>,
- executor: std::sync::Arc<gpui::executor::Background>,
- ) -> (
- Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
- Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
- ) {
- use futures::channel::mpsc;
- use io::{Error, ErrorKind};
- use std::sync::Arc;
+ fn channel(
+ kill_rx: postage::barrier::Receiver,
+ executor: std::sync::Arc<gpui::executor::Background>,
+ ) -> (
+ Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+ Box<
+ dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+ >,
+ ) {
+ use futures::channel::mpsc;
+ use std::{
+ io::{Error, ErrorKind},
+ sync::Arc,
+ };
- let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
- let tx = tx
- .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
- .with({
- let executor = Arc::downgrade(&executor);
+ let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
+
+ let tx = tx
+ .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+ .with({
+ let kill_rx = kill_rx.clone();
+ let executor = Arc::downgrade(&executor);
+ move |msg| {
+ let mut kill_rx = kill_rx.clone();
+ let executor = executor.clone();
+ Box::pin(async move {
+ if let Some(executor) = executor.upgrade() {
+ executor.simulate_random_delay().await;
+ }
+
+ // Writes to a half-open TCP connection will error.
+ if kill_rx.try_recv().is_ok() {
+ std::io::Result::Err(
+ Error::new(ErrorKind::Other, "connection lost").into(),
+ )?;
+ }
+
+ Ok(msg)
+ })
+ }
+ });
+
+ let rx = rx.then({
let kill_rx = kill_rx.clone();
+ let executor = Arc::downgrade(&executor);
move |msg| {
- let kill_rx = kill_rx.clone();
+ let mut kill_rx = kill_rx.clone();
let executor = executor.clone();
Box::pin(async move {
if let Some(executor) = executor.upgrade() {
executor.simulate_random_delay().await;
}
- if kill_rx.borrow().is_none() {
- Ok(msg)
- } else {
- Err(Error::new(ErrorKind::Other, "connection killed").into())
+
+ // Reads from a half-open TCP connection will hang.
+ if kill_rx.try_recv().is_ok() {
+ futures::future::pending::<()>().await;
}
+
+ Ok(msg)
})
}
});
- let rx = rx.then(move |msg| {
- let executor = Arc::downgrade(&executor);
- Box::pin(async move {
- if let Some(executor) = executor.upgrade() {
- executor.simulate_random_delay().await;
- }
- msg
- })
- });
- let rx = KillableReceiver { kill_rx, rx };
-
- (Box::new(tx), Box::new(rx))
- }
-}
-
-struct KillableReceiver<S> {
- rx: S,
- kill_rx: postage::watch::Receiver<Option<()>>,
-}
-
-impl<S: Unpin + Stream<Item = WebSocketMessage>> Stream for KillableReceiver<S> {
- type Item = Result<WebSocketMessage, WebSocketError>;
- fn poll_next(
- mut self: std::pin::Pin<&mut Self>,
- cx: &mut std::task::Context<'_>,
- ) -> Poll<Option<Self::Item>> {
- if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
- Poll::Ready(Some(Err(io::Error::new(
- io::ErrorKind::Other,
- "connection killed",
- )
- .into())))
- } else {
- self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+ (Box::new(tx), Box::new(rx))
}
}
}
@@ -88,13 +88,14 @@ pub struct Peer {
#[derive(Clone)]
pub struct ConnectionState {
- outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Envelope>,
+ outgoing_tx: futures::channel::mpsc::UnboundedSender<proto::Message>,
next_message_id: Arc<AtomicU32>,
response_channels:
Arc<Mutex<Option<HashMap<u32, oneshot::Sender<(proto::Envelope, barrier::Sender)>>>>>,
}
-const WRITE_TIMEOUT: Duration = Duration::from_secs(10);
+const KEEPALIVE_INTERVAL: Duration = Duration::from_secs(1);
+const WRITE_TIMEOUT: Duration = Duration::from_secs(2);
impl Peer {
pub fn new() -> Arc<Self> {
@@ -104,14 +105,20 @@ impl Peer {
})
}
- pub async fn add_connection(
+ pub async fn add_connection<F, Fut, Out>(
self: &Arc<Self>,
connection: Connection,
+ create_timer: F,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
- ) {
+ )
+ where
+ F: Send + Fn(Duration) -> Fut,
+ Fut: Send + Future<Output = Out>,
+ Out: Send,
+ {
// For outgoing messages, use an unbounded channel so that application code
// can always send messages without yielding. For incoming messages, use a
// bounded channel so that other peers will receive backpressure if they send
@@ -121,7 +128,7 @@ impl Peer {
let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst));
let connection_state = ConnectionState {
- outgoing_tx,
+ outgoing_tx: outgoing_tx.clone(),
next_message_id: Default::default(),
response_channels: Arc::new(Mutex::new(Some(Default::default()))),
};
@@ -131,39 +138,59 @@ impl Peer {
let this = self.clone();
let response_channels = connection_state.response_channels.clone();
let handle_io = async move {
- let result = 'outer: loop {
- let read_message = reader.read_message().fuse();
+ let _end_connection = util::defer(|| {
+ response_channels.lock().take();
+ this.connections.write().remove(&connection_id);
+ });
+
+ // Send messages on this frequency so the connection isn't closed.
+ let keepalive_timer = create_timer(KEEPALIVE_INTERVAL).fuse();
+ futures::pin_mut!(keepalive_timer);
+
+ loop {
+ let read_message = reader.read().fuse();
futures::pin_mut!(read_message);
+
+ // Disconnect if we don't receive messages at least this frequently.
+ let receive_timeout = create_timer(3 * KEEPALIVE_INTERVAL).fuse();
+ futures::pin_mut!(receive_timeout);
+
loop {
futures::select_biased! {
outgoing = outgoing_rx.next().fuse() => match outgoing {
Some(outgoing) => {
- match writer.write_message(&outgoing).timeout(WRITE_TIMEOUT).await {
- None => break 'outer Err(anyhow!("timed out writing RPC message")),
- Some(Err(result)) => break 'outer Err(result).context("failed to write RPC message"),
- _ => {}
+ if let Some(result) = writer.write(outgoing).timeout(WRITE_TIMEOUT).await {
+ result.context("failed to write RPC message")?;
+ keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
+ } else {
+ Err(anyhow!("timed out writing message"))?;
}
}
- None => break 'outer Ok(()),
+ None => return Ok(()),
},
- incoming = read_message => match incoming {
- Ok(incoming) => {
+ incoming = read_message => {
+ let incoming = incoming.context("received invalid RPC message")?;
+ if let proto::Message::Envelope(incoming) = incoming {
if incoming_tx.send(incoming).await.is_err() {
- break 'outer Ok(());
+ return Ok(());
}
- break;
- }
- Err(error) => {
- break 'outer Err(error).context("received invalid RPC message")
}
+ break;
},
+ _ = keepalive_timer => {
+ if let Some(result) = writer.write(proto::Message::Ping).timeout(WRITE_TIMEOUT).await {
+ result.context("failed to send keepalive")?;
+ keepalive_timer.set(create_timer(KEEPALIVE_INTERVAL).fuse());
+ } else {
+ Err(anyhow!("timed out sending keepalive"))?;
+ }
+ }
+ _ = receive_timeout => {
+ Err(anyhow!("delay between messages too long"))?
+ }
}
}
- };
-
- response_channels.lock().take();
- this.connections.write().remove(&connection_id);
- result
+ }
};
let response_channels = connection_state.response_channels.clone();
@@ -191,18 +218,31 @@ impl Peer {
None
} else {
- if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
- Some(envelope)
- } else {
+ proto::build_typed_envelope(connection_id, incoming).or_else(|| {
log::error!("unable to construct a typed envelope");
None
- }
+ })
}
}
});
(connection_id, handle_io, incoming_rx.boxed())
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub async fn add_test_connection(
+ self: &Arc<Self>,
+ connection: Connection,
+ executor: Arc<gpui::executor::Background>,
+ ) -> (
+ ConnectionId,
+ impl Future<Output = anyhow::Result<()>> + Send,
+ BoxStream<'static, Box<dyn AnyTypedEnvelope>>,
+ ) {
+ let executor = executor.clone();
+ self.add_connection(connection, move |duration| executor.timer(duration))
+ .await
+ }
+
pub fn disconnect(&self, connection_id: ConnectionId) {
self.connections.write().remove(&connection_id);
}
@@ -245,11 +285,11 @@ impl Peer {
.insert(message_id, tx);
connection
.outgoing_tx
- .unbounded_send(request.into_envelope(
+ .unbounded_send(proto::Message::Envelope(request.into_envelope(
message_id,
None,
original_sender_id.map(|id| id.0),
- ))
+ )))
.map_err(|_| anyhow!("connection was closed"))?;
Ok(())
});
@@ -272,7 +312,9 @@ impl Peer {
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
- .unbounded_send(message.into_envelope(message_id, None, None))?;
+ .unbounded_send(proto::Message::Envelope(
+ message.into_envelope(message_id, None, None),
+ ))?;
Ok(())
}
@@ -288,7 +330,11 @@ impl Peer {
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
- .unbounded_send(message.into_envelope(message_id, None, Some(sender_id.0)))?;
+ .unbounded_send(proto::Message::Envelope(message.into_envelope(
+ message_id,
+ None,
+ Some(sender_id.0),
+ )))?;
Ok(())
}
@@ -303,7 +349,11 @@ impl Peer {
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
- .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
+ .unbounded_send(proto::Message::Envelope(response.into_envelope(
+ message_id,
+ Some(receipt.message_id),
+ None,
+ )))?;
Ok(())
}
@@ -318,7 +368,11 @@ impl Peer {
.fetch_add(1, atomic::Ordering::SeqCst);
connection
.outgoing_tx
- .unbounded_send(response.into_envelope(message_id, Some(receipt.message_id), None))?;
+ .unbounded_send(proto::Message::Envelope(response.into_envelope(
+ message_id,
+ Some(receipt.message_id),
+ None,
+ )))?;
Ok(())
}
@@ -347,17 +401,23 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
- let (client1_to_server_conn, server_to_client_1_conn, _) =
+ let (client1_to_server_conn, server_to_client_1_conn, _kill) =
Connection::in_memory(cx.background());
- let (client1_conn_id, io_task1, client1_incoming) =
- client1.add_connection(client1_to_server_conn).await;
- let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await;
-
- let (client2_to_server_conn, server_to_client_2_conn, _) =
+ let (client1_conn_id, io_task1, client1_incoming) = client1
+ .add_test_connection(client1_to_server_conn, cx.background())
+ .await;
+ let (_, io_task2, server_incoming1) = server
+ .add_test_connection(server_to_client_1_conn, cx.background())
+ .await;
+
+ let (client2_to_server_conn, server_to_client_2_conn, _kill) =
Connection::in_memory(cx.background());
- let (client2_conn_id, io_task3, client2_incoming) =
- client2.add_connection(client2_to_server_conn).await;
- let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await;
+ let (client2_conn_id, io_task3, client2_incoming) = client2
+ .add_test_connection(client2_to_server_conn, cx.background())
+ .await;
+ let (_, io_task4, server_incoming2) = server
+ .add_test_connection(server_to_client_2_conn, cx.background())
+ .await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@@ -438,12 +498,14 @@ mod tests {
let server = Peer::new();
let client = Peer::new();
- let (client_to_server_conn, server_to_client_conn, _) =
+ let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.background());
- let (client_to_server_conn_id, io_task1, mut client_incoming) =
- client.add_connection(client_to_server_conn).await;
- let (server_to_client_conn_id, io_task2, mut server_incoming) =
- server.add_connection(server_to_client_conn).await;
+ let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+ .add_test_connection(client_to_server_conn, cx.background())
+ .await;
+ let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+ .add_test_connection(server_to_client_conn, cx.background())
+ .await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@@ -536,12 +598,14 @@ mod tests {
let server = Peer::new();
let client = Peer::new();
- let (client_to_server_conn, server_to_client_conn, _) =
+ let (client_to_server_conn, server_to_client_conn, _kill) =
Connection::in_memory(cx.background());
- let (client_to_server_conn_id, io_task1, mut client_incoming) =
- client.add_connection(client_to_server_conn).await;
- let (server_to_client_conn_id, io_task2, mut server_incoming) =
- server.add_connection(server_to_client_conn).await;
+ let (client_to_server_conn_id, io_task1, mut client_incoming) = client
+ .add_test_connection(client_to_server_conn, cx.background())
+ .await;
+ let (server_to_client_conn_id, io_task2, mut server_incoming) = server
+ .add_test_connection(server_to_client_conn, cx.background())
+ .await;
executor.spawn(io_task1).detach();
executor.spawn(io_task2).detach();
@@ -646,10 +710,12 @@ mod tests {
async fn test_disconnect(cx: &mut TestAppContext) {
let executor = cx.foreground();
- let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
+ let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
let client = Peer::new();
- let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+ let (connection_id, io_handler, mut incoming) = client
+ .add_test_connection(client_conn, cx.background())
+ .await;
let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel();
executor
@@ -680,10 +746,12 @@ mod tests {
#[gpui::test(iterations = 50)]
async fn test_io_error(cx: &mut TestAppContext) {
let executor = cx.foreground();
- let (client_conn, mut server_conn, _) = Connection::in_memory(cx.background());
+ let (client_conn, mut server_conn, _kill) = Connection::in_memory(cx.background());
let client = Peer::new();
- let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await;
+ let (connection_id, io_handler, mut incoming) = client
+ .add_test_connection(client_conn, cx.background())
+ .await;
executor.spawn(io_handler).detach();
executor
.spawn(async move { incoming.next().await })
@@ -2,7 +2,7 @@ use super::{ConnectionId, PeerId, TypedEnvelope};
use anyhow::Result;
use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use futures::{SinkExt as _, StreamExt as _};
-use prost::Message;
+use prost::Message as _;
use std::any::{Any, TypeId};
use std::{
io,
@@ -283,6 +283,13 @@ pub struct MessageStream<S> {
encoding_buffer: Vec<u8>,
}
+#[derive(Debug)]
+pub enum Message {
+ Envelope(Envelope),
+ Ping,
+ Pong,
+}
+
impl<S> MessageStream<S> {
pub fn new(stream: S) -> Self {
Self {
@@ -300,22 +307,37 @@ impl<S> MessageStream<S>
where
S: futures::Sink<WebSocketMessage, Error = WebSocketError> + Unpin,
{
- /// Write a given protobuf message to the stream.
- pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
+ pub async fn write(&mut self, message: Message) -> Result<(), WebSocketError> {
#[cfg(any(test, feature = "test-support"))]
const COMPRESSION_LEVEL: i32 = -7;
#[cfg(not(any(test, feature = "test-support")))]
const COMPRESSION_LEVEL: i32 = 4;
- self.encoding_buffer.resize(message.encoded_len(), 0);
- self.encoding_buffer.clear();
- message
- .encode(&mut self.encoding_buffer)
- .map_err(|err| io::Error::from(err))?;
- let buffer =
- zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
- self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+ match message {
+ Message::Envelope(message) => {
+ self.encoding_buffer.resize(message.encoded_len(), 0);
+ self.encoding_buffer.clear();
+ message
+ .encode(&mut self.encoding_buffer)
+ .map_err(|err| io::Error::from(err))?;
+ let buffer =
+ zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL)
+ .unwrap();
+ self.stream.send(WebSocketMessage::Binary(buffer)).await?;
+ }
+ Message::Ping => {
+ self.stream
+ .send(WebSocketMessage::Ping(Default::default()))
+ .await?;
+ }
+ Message::Pong => {
+ self.stream
+ .send(WebSocketMessage::Pong(Default::default()))
+ .await?;
+ }
+ }
+
Ok(())
}
}
@@ -324,8 +346,7 @@ impl<S> MessageStream<S>
where
S: futures::Stream<Item = Result<WebSocketMessage, WebSocketError>> + Unpin,
{
- /// Read a protobuf message of the given type from the stream.
- pub async fn read_message(&mut self) -> Result<Envelope, WebSocketError> {
+ pub async fn read(&mut self) -> Result<Message, WebSocketError> {
while let Some(bytes) = self.stream.next().await {
match bytes? {
WebSocketMessage::Binary(bytes) => {
@@ -333,8 +354,10 @@ where
zstd::stream::copy_decode(bytes.as_slice(), &mut self.encoding_buffer).unwrap();
let envelope = Envelope::decode(self.encoding_buffer.as_slice())
.map_err(io::Error::from)?;
- return Ok(envelope);
+ return Ok(Message::Envelope(envelope));
}
+ WebSocketMessage::Ping(_) => return Ok(Message::Ping),
+ WebSocketMessage::Pong(_) => return Ok(Message::Pong),
WebSocketMessage::Close(_) => break,
_ => {}
}
@@ -16,6 +16,7 @@ required-features = ["seed-support"]
collections = { path = "../collections" }
rpc = { path = "../rpc" }
anyhow = "1.0.40"
+async-io = "1.3"
async-std = { version = "1.8.0", features = ["attributes"] }
async-trait = "0.1.50"
async-tungstenite = "0.16"
@@ -6,6 +6,7 @@ use super::{
AppState,
};
use anyhow::anyhow;
+use async_io::Timer;
use async_std::task;
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use collections::{HashMap, HashSet};
@@ -16,7 +17,12 @@ use rpc::{
Connection, ConnectionId, Peer, TypedEnvelope,
};
use sha1::{Digest as _, Sha1};
-use std::{any::TypeId, future::Future, sync::Arc, time::Instant};
+use std::{
+ any::TypeId,
+ future::Future,
+ sync::Arc,
+ time::{Duration, Instant},
+};
use store::{Store, Worktree};
use surf::StatusCode;
use tide::log;
@@ -40,10 +46,13 @@ pub struct Server {
notifications: Option<mpsc::UnboundedSender<()>>,
}
-pub trait Executor {
+pub trait Executor: Send + Clone {
+ type Timer: Send + Future;
fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
+ fn timer(&self, duration: Duration) -> Self::Timer;
}
+#[derive(Clone)]
pub struct RealExecutor;
const MESSAGE_COUNT_PER_PAGE: usize = 100;
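The Executor trait now requires a Send + Clone implementor that can also produce timers. A hedged sketch of what another implementation could look like; SmolExecutor and the use of the smol crate here are assumptions for illustration, while RealExecutor and the gpui test executor below are the implementations the commit actually provides.

use std::{future::Future, time::Duration};

#[derive(Clone)]
struct SmolExecutor;

impl Executor for SmolExecutor {
    type Timer = smol::Timer;

    fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
        smol::spawn(future).detach();
    }

    fn timer(&self, duration: Duration) -> Self::Timer {
        // smol::Timer is Send and implements Future, satisfying the associated type bounds.
        smol::Timer::after(duration)
    }
}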
@@ -167,8 +176,18 @@ impl Server {
) -> impl Future<Output = ()> {
let mut this = self.clone();
async move {
- let (connection_id, handle_io, mut incoming_rx) =
- this.peer.add_connection(connection).await;
+ let (connection_id, handle_io, mut incoming_rx) = this
+ .peer
+ .add_connection(connection, {
+ let executor = executor.clone();
+ move |duration| {
+ let timer = executor.timer(duration);
+ async move {
+ timer.await;
+ }
+ }
+ })
+ .await;
if let Some(send_connection_id) = send_connection_id.as_mut() {
let _ = send_connection_id.send(connection_id).await;
@@ -883,9 +902,15 @@ impl Server {
}
impl Executor for RealExecutor {
+ type Timer = Timer;
+
fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
task::spawn(future);
}
+
+ fn timer(&self, duration: Duration) -> Self::Timer {
+ Timer::after(duration)
+ }
}
fn broadcast<F>(
@@ -1005,7 +1030,7 @@ mod tests {
};
use lsp;
use parking_lot::Mutex;
- use postage::{sink::Sink, watch};
+ use postage::{barrier, watch};
use project::{
fs::{FakeFs, Fs as _},
search::SearchQuery,
@@ -1759,7 +1784,7 @@ mod tests {
}
#[gpui::test(iterations = 10)]
- async fn test_peer_disconnection(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
+ async fn test_leaving_project(cx_a: &mut TestAppContext, cx_b: &mut TestAppContext) {
cx_a.foreground().forbid_parking();
let lang_registry = Arc::new(LanguageRegistry::new());
let fs = FakeFs::new(cx_a.background());
@@ -1817,16 +1842,40 @@ mod tests {
.await
.unwrap();
- // See that a guest has joined as client A.
+ // Client A sees that a guest has joined.
project_a
.condition(&cx_a, |p, _| p.collaborators().len() == 1)
.await;
- // Drop client B's connection and ensure client A observes client B leaving the worktree.
+ // Drop client B's connection and ensure client A observes client B leaving the project.
client_b.disconnect(&cx_b.to_async()).unwrap();
project_a
.condition(&cx_a, |p, _| p.collaborators().len() == 0)
.await;
+
+ // Rejoin the project as client B
+ let _project_b = Project::remote(
+ project_id,
+ client_b.clone(),
+ client_b.user_store.clone(),
+ lang_registry.clone(),
+ fs.clone(),
+ &mut cx_b.to_async(),
+ )
+ .await
+ .unwrap();
+
+ // Client A sees that a guest has re-joined.
+ project_a
+ .condition(&cx_a, |p, _| p.collaborators().len() == 1)
+ .await;
+
+ // Simulate connection loss for client B and ensure client A observes client B leaving the project.
+ server.disconnect_client(client_b.current_user_id(cx_b));
+ cx_a.foreground().advance_clock(Duration::from_secs(3));
+ project_a
+ .condition(&cx_a, |p, _| p.collaborators().len() == 0)
+ .await;
}
#[gpui::test(iterations = 10)]
@@ -2683,8 +2732,6 @@ mod tests {
.read_with(cx_a, |tree, _| tree.as_local().unwrap().scan_complete())
.await;
- eprintln!("sharing");
-
project_a.update(cx_a, |p, cx| p.share(cx)).await.unwrap();
// Join the worktree as client B.
@@ -3850,6 +3897,7 @@ mod tests {
// Disconnect client B, ensuring we can still access its cached channel data.
server.forbid_connections();
server.disconnect_client(client_b.current_user_id(&cx_b));
+ cx_b.foreground().advance_clock(Duration::from_secs(3));
while !matches!(
status_b.next().await,
Some(client::Status::ReconnectionError { .. })
@@ -4340,7 +4388,7 @@ mod tests {
server: Arc<Server>,
foreground: Rc<executor::Foreground>,
notifications: mpsc::UnboundedReceiver<()>,
- connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+ connection_killers: Arc<Mutex<HashMap<UserId, barrier::Sender>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
@@ -4444,9 +4492,7 @@ mod tests {
}
fn disconnect_client(&self, user_id: UserId) {
- if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
- let _ = kill_conn.try_send(Some(()));
- }
+ self.connection_killers.lock().remove(&user_id);
}
fn forbid_connections(&self) {
@@ -5031,9 +5077,15 @@ mod tests {
}
impl Executor for Arc<gpui::executor::Background> {
+ type Timer = gpui::executor::Timer;
+
fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
self.spawn(future).detach();
}
+
+ fn timer(&self, duration: Duration) -> Self::Timer {
+ self.as_ref().timer(duration)
+ }
}
fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> {
@@ -0,0 +1,69 @@
+use clock::ReplicaId;
+
+pub struct Network<T: Clone, R: rand::Rng> {
+ inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
+ all_messages: Vec<T>,
+ rng: R,
+}
+
+#[derive(Clone)]
+struct Envelope<T: Clone> {
+ message: T,
+}
+
+impl<T: Clone, R: rand::Rng> Network<T, R> {
+ pub fn new(rng: R) -> Self {
+ Network {
+ inboxes: Default::default(),
+ all_messages: Vec::new(),
+ rng,
+ }
+ }
+
+ pub fn add_peer(&mut self, id: ReplicaId) {
+ self.inboxes.insert(id, Vec::new());
+ }
+
+ pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) {
+ self.inboxes
+ .insert(new_replica_id, self.inboxes[&old_replica_id].clone());
+ }
+
+ pub fn is_idle(&self) -> bool {
+ self.inboxes.values().all(|i| i.is_empty())
+ }
+
+ pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
+ for (replica, inbox) in self.inboxes.iter_mut() {
+ if *replica != sender {
+ for message in &messages {
+ // Insert one or more duplicates of this message, potentially *before* the previous
+ // message sent by this peer to simulate out-of-order delivery.
+ for _ in 0..self.rng.gen_range(1..4) {
+ let insertion_index = self.rng.gen_range(0..inbox.len() + 1);
+ inbox.insert(
+ insertion_index,
+ Envelope {
+ message: message.clone(),
+ },
+ );
+ }
+ }
+ }
+ }
+ self.all_messages.extend(messages);
+ }
+
+ pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
+ !self.inboxes[&receiver].is_empty()
+ }
+
+ pub fn receive(&mut self, receiver: ReplicaId) -> Vec<T> {
+ let inbox = self.inboxes.get_mut(&receiver).unwrap();
+ let count = self.rng.gen_range(0..inbox.len() + 1);
+ inbox
+ .drain(0..count)
+ .map(|envelope| envelope.message)
+ .collect()
+ }
+}
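A hedged usage sketch for the Network fixture above, which now lives behind the text crate's test-support feature; the message strings are arbitrary. Peer 0 broadcasts, and peer 1 drains its inbox, receiving the messages possibly duplicated and reordered.

use rand::{rngs::StdRng, SeedableRng};
use text::network::Network;

fn main() {
    let mut network = Network::new(StdRng::seed_from_u64(0));
    network.add_peer(0);
    network.add_peer(1);

    network.broadcast(0, vec!["op-1", "op-2"]);
    assert!(network.has_unreceived(1));

    let mut received = Vec::new();
    while network.has_unreceived(1) {
        // Each receive call delivers a random prefix of the pending messages.
        received.extend(network.receive(1));
    }
    assert!(!received.is_empty());
    assert!(network.is_idle());
}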
@@ -1,4 +1,4 @@
-use super::*;
+use super::{network::Network, *};
use clock::ReplicaId;
use rand::prelude::*;
use std::{
@@ -7,7 +7,6 @@ use std::{
iter::Iterator,
time::{Duration, Instant},
};
-use util::test::Network;
#[cfg(test)]
#[ctor::ctor]
@@ -1,5 +1,7 @@
mod anchor;
pub mod locator;
+#[cfg(any(test, feature = "test-support"))]
+pub mod network;
pub mod operation_queue;
mod patch;
mod point;
@@ -7,10 +7,9 @@ edition = "2021"
doctest = false
[features]
-test-support = ["clock", "rand", "serde_json", "tempdir"]
+test-support = ["rand", "serde_json", "tempdir"]
[dependencies]
-clock = { path = "../clock", optional = true }
anyhow = "1.0.38"
futures = "0.3"
log = "0.4"
@@ -123,6 +123,18 @@ where
}
}
+struct Defer<F: FnOnce()>(Option<F>);
+
+impl<F: FnOnce()> Drop for Defer<F> {
+ fn drop(&mut self) {
+ self.0.take().map(|f| f());
+ }
+}
+
+pub fn defer<F: FnOnce()>(f: F) -> impl Drop {
+ Defer(Some(f))
+}
+
#[cfg(test)]
mod tests {
use super::*;
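The defer helper added above is what the peer now uses (util::defer in peer.rs) to tear down connection state no matter how the I/O loop exits. A minimal hypothetical usage sketch:

fn do_work() -> anyhow::Result<()> {
    // Runs when `_cleanup` is dropped, including on early returns and `?` errors.
    let _cleanup = defer(|| println!("cleaned up"));

    // ... fallible work ...
    Ok(())
}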
@@ -1,75 +1,6 @@
-use clock::ReplicaId;
use std::path::{Path, PathBuf};
use tempdir::TempDir;
-#[derive(Clone)]
-struct Envelope<T: Clone> {
- message: T,
-}
-
-pub struct Network<T: Clone, R: rand::Rng> {
- inboxes: std::collections::BTreeMap<ReplicaId, Vec<Envelope<T>>>,
- all_messages: Vec<T>,
- rng: R,
-}
-
-impl<T: Clone, R: rand::Rng> Network<T, R> {
- pub fn new(rng: R) -> Self {
- Network {
- inboxes: Default::default(),
- all_messages: Vec::new(),
- rng,
- }
- }
-
- pub fn add_peer(&mut self, id: ReplicaId) {
- self.inboxes.insert(id, Vec::new());
- }
-
- pub fn replicate(&mut self, old_replica_id: ReplicaId, new_replica_id: ReplicaId) {
- self.inboxes
- .insert(new_replica_id, self.inboxes[&old_replica_id].clone());
- }
-
- pub fn is_idle(&self) -> bool {
- self.inboxes.values().all(|i| i.is_empty())
- }
-
- pub fn broadcast(&mut self, sender: ReplicaId, messages: Vec<T>) {
- for (replica, inbox) in self.inboxes.iter_mut() {
- if *replica != sender {
- for message in &messages {
- // Insert one or more duplicates of this message, potentially *before* the previous
- // message sent by this peer to simulate out-of-order delivery.
- for _ in 0..self.rng.gen_range(1..4) {
- let insertion_index = self.rng.gen_range(0..inbox.len() + 1);
- inbox.insert(
- insertion_index,
- Envelope {
- message: message.clone(),
- },
- );
- }
- }
- }
- }
- self.all_messages.extend(messages);
- }
-
- pub fn has_unreceived(&self, receiver: ReplicaId) -> bool {
- !self.inboxes[&receiver].is_empty()
- }
-
- pub fn receive(&mut self, receiver: ReplicaId) -> Vec<T> {
- let inbox = self.inboxes.get_mut(&receiver).unwrap();
- let count = self.rng.gen_range(0..inbox.len() + 1);
- inbox
- .drain(0..count)
- .map(|envelope| envelope.message)
- .collect()
- }
-}
-
pub fn temp_tree(tree: serde_json::Value) -> TempDir {
let dir = TempDir::new("").unwrap();
write_tree(dir.path(), tree);