From 900010160f1c9ffc909bcd8c8788cf93c495d3e6 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 8 Sep 2021 18:58:59 +0200 Subject: [PATCH] WIP Co-Authored-By: Max Brunsfeld --- gpui/src/executor.rs | 38 ++++++++++++++- server/src/rpc.rs | 2 +- zed/src/channel.rs | 62 +++---------------------- zed/src/rpc.rs | 107 +++++++++++++++++++++++++++++++++++-------- zed/src/test.rs | 54 +++++++++++++++++++++- 5 files changed, 182 insertions(+), 81 deletions(-) diff --git a/gpui/src/executor.rs b/gpui/src/executor.rs index b135f5034d6110bef3ada67e832503e64cc9606e..7a223a96d24a6cd19986282349ab69ca73d07eb2 100644 --- a/gpui/src/executor.rs +++ b/gpui/src/executor.rs @@ -3,8 +3,9 @@ use async_task::Runnable; pub use async_task::Task; use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString}; use parking_lot::Mutex; +use postage::{barrier, prelude::Stream as _}; use rand::prelude::*; -use smol::{channel, prelude::*, Executor}; +use smol::{channel, prelude::*, Executor, Timer}; use std::{ fmt::{self, Debug}, marker::PhantomData, @@ -18,7 +19,7 @@ use std::{ }, task::{Context, Poll}, thread, - time::Duration, + time::{Duration, Instant}, }; use waker_fn::waker_fn; @@ -49,6 +50,8 @@ struct DeterministicState { spawned_from_foreground: Vec<(Runnable, Backtrace)>, forbid_parking: bool, block_on_ticks: RangeInclusive, + now: Instant, + pending_sleeps: Vec<(Instant, barrier::Sender)>, } pub struct Deterministic { @@ -67,6 +70,8 @@ impl Deterministic { spawned_from_foreground: Default::default(), forbid_parking: false, block_on_ticks: 0..=1000, + now: Instant::now(), + pending_sleeps: Default::default(), })), parker: Default::default(), } @@ -407,6 +412,35 @@ impl Foreground { } } + pub async fn sleep(&self, duration: Duration) { + match self { + Self::Deterministic(executor) => { + let (tx, mut rx) = barrier::channel(); + { + let mut state = executor.state.lock(); + let wakeup_at = state.now + duration; + state.pending_sleeps.push((wakeup_at, tx)); + } + rx.recv().await; + } + _ => { + Timer::after(duration).await; + } + } + } + + pub fn advance_clock(&self, duration: Duration) { + match self { + Self::Deterministic(executor) => { + let mut state = executor.state.lock(); + state.now += duration; + let now = state.now; + state.pending_sleeps.retain(|(wakeup, _)| *wakeup > now); + } + _ => panic!("this method can only be called on a deterministic executor"), + } + } + pub fn set_block_on_ticks(&self, range: RangeInclusive) { match self { Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range, diff --git a/server/src/rpc.rs b/server/src/rpc.rs index 7562bdf74b7f687ed968bd97ed36f99c4978c4a9..e1b1bce05860a223e2d0a169c663711e9c2524c7 100644 --- a/server/src/rpc.rs +++ b/server/src/rpc.rs @@ -1469,7 +1469,7 @@ mod tests { .await; // Drop client B's connection and ensure client A observes client B leaving the worktree. - client_b.disconnect().await.unwrap(); + client_b.disconnect(&cx_b.to_async()).await.unwrap(); worktree_a .condition(&cx_a, |tree, _| tree.peers().len() == 0) .await; diff --git a/zed/src/channel.rs b/zed/src/channel.rs index 38329c70d4bf39968df533b53ec8c22fa2c48d09..234e3e1e5f2382ac8ad6ca716de27d85caaeb240 100644 --- a/zed/src/channel.rs +++ b/zed/src/channel.rs @@ -443,9 +443,8 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count { #[cfg(test)] mod tests { use super::*; + use crate::test::FakeServer; use gpui::TestAppContext; - use postage::mpsc::Receiver; - use zrpc::{test::Channel, ConnectionId, Peer, Receipt}; #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { @@ -458,7 +457,7 @@ mod tests { channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None)); // Get the available channels. - let get_channels = server.receive::().await; + let get_channels = server.receive::().await.unwrap(); server .respond( get_channels.receipt(), @@ -489,7 +488,7 @@ mod tests { }) .unwrap(); channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty())); - let join_channel = server.receive::().await; + let join_channel = server.receive::().await.unwrap(); server .respond( join_channel.receipt(), @@ -514,7 +513,7 @@ mod tests { .await; // Client requests all users for the received messages - let mut get_users = server.receive::().await; + let mut get_users = server.receive::().await.unwrap(); get_users.payload.user_ids.sort(); assert_eq!(get_users.payload.user_ids, vec![5, 6]); server @@ -571,7 +570,7 @@ mod tests { .await; // Client requests user for message since they haven't seen them yet - let get_users = server.receive::().await; + let get_users = server.receive::().await.unwrap(); assert_eq!(get_users.payload.user_ids, vec![7]); server .respond( @@ -607,7 +606,7 @@ mod tests { channel.update(&mut cx, |channel, cx| { assert!(channel.load_more_messages(cx)); }); - let get_messages = server.receive::().await; + let get_messages = server.receive::().await.unwrap(); assert_eq!(get_messages.payload.channel_id, 5); assert_eq!(get_messages.payload.before_message_id, 10); server @@ -653,53 +652,4 @@ mod tests { ); }); } - - struct FakeServer { - peer: Arc, - incoming: Receiver>, - connection_id: ConnectionId, - } - - impl FakeServer { - async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Self { - let (client_conn, server_conn) = Channel::bidirectional(); - let peer = Peer::new(); - let (connection_id, io, incoming) = peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - - client - .set_connection(user_id, client_conn, &cx.to_async()) - .await - .unwrap(); - - Self { - peer, - incoming, - connection_id, - } - } - - async fn send(&self, message: T) { - self.peer.send(self.connection_id, message).await.unwrap(); - } - - async fn receive(&mut self) -> TypedEnvelope { - *self - .incoming - .recv() - .await - .unwrap() - .into_any() - .downcast::>() - .unwrap() - } - - async fn respond( - &self, - receipt: Receipt, - response: T::Response, - ) { - self.peer.respond(receipt, response).await.unwrap() - } - } } diff --git a/zed/src/rpc.rs b/zed/src/rpc.rs index f3df1e7b88d0fac18a9380339d38a76c56efbff3..b36ec4d376658e9362cb07ef6cb2f6821a824765 100644 --- a/zed/src/rpc.rs +++ b/zed/src/rpc.rs @@ -3,10 +3,12 @@ use anyhow::{anyhow, Context, Result}; use async_tungstenite::tungstenite::{ http::Request, Error as WebSocketError, Message as WebSocketMessage, }; +use futures::StreamExt as _; use gpui::{AsyncAppContext, Entity, ModelContext, Task}; use lazy_static::lazy_static; use parking_lot::RwLock; use postage::{prelude::Stream, watch}; +use smol::Timer; use std::{ any::TypeId, collections::HashMap, @@ -42,6 +44,10 @@ pub enum Status { user_id: u64, }, ConnectionLost, + Reconnecting, + ReconnectionError { + next_reconnection: Instant, + }, } struct ClientState { @@ -51,6 +57,8 @@ struct ClientState { (TypeId, u64), Box, &mut AsyncAppContext)>, >, + _maintain_connection: Option>, + heartbeat_interval: Duration, } impl Default for ClientState { @@ -59,6 +67,8 @@ impl Default for ClientState { status: watch::channel_with(Status::Disconnected), entity_id_extractors: Default::default(), model_handlers: Default::default(), + _maintain_connection: None, + heartbeat_interval: Duration::from_secs(5), } } } @@ -95,9 +105,35 @@ impl Client { self.state.read().status.1.clone() } - fn set_status(&self, status: Status) { + fn set_status(self: &Arc, status: Status, cx: &AsyncAppContext) { let mut state = self.state.write(); *state.status.0.borrow_mut() = status; + match status { + Status::Connected { .. } => { + let heartbeat_interval = state.heartbeat_interval; + let this = self.clone(); + let foreground = cx.foreground(); + state._maintain_connection = Some(cx.foreground().spawn(async move { + let mut next_ping_id = 0; + loop { + foreground.sleep(heartbeat_interval).await; + this.request(proto::Ping { id: next_ping_id }) + .await + .unwrap(); + next_ping_id += 1; + } + })); + } + Status::ConnectionLost => { + state._maintain_connection = Some(cx.foreground().spawn(async move { + // TODO: try to reconnect + })); + } + Status::Disconnected => { + state._maintain_connection.take(); + } + _ => {} + } } pub fn subscribe_from_model( @@ -167,14 +203,14 @@ impl Client { let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?; let user_id = user_id.parse::()?; - self.set_status(Status::Connecting); + self.set_status(Status::Connecting, cx); match self.connect(user_id, &access_token, cx).await { Ok(()) => { log::info!("connected to rpc address {}", *ZED_SERVER_URL); Ok(()) } Err(err) => { - self.set_status(Status::ConnectionError); + self.set_status(Status::ConnectionError, cx); Err(err) } } @@ -256,20 +292,24 @@ impl Client { .detach(); } - self.set_status(Status::Connected { - connection_id, - user_id, - }); + self.set_status( + Status::Connected { + connection_id, + user_id, + }, + cx, + ); let handle_io = cx.background().spawn(handle_io); let this = self.clone(); + let cx = cx.clone(); cx.foreground() .spawn(async move { match handle_io.await { - Ok(()) => this.set_status(Status::Disconnected), + Ok(()) => this.set_status(Status::Disconnected, &cx), Err(err) => { log::error!("connection error: {:?}", err); - this.set_status(Status::ConnectionLost); + this.set_status(Status::ConnectionLost, &cx); } } }) @@ -359,10 +399,10 @@ impl Client { }) } - pub async fn disconnect(&self) -> Result<()> { + pub async fn disconnect(self: &Arc, cx: &AsyncAppContext) -> Result<()> { let conn_id = self.connection_id()?; self.peer.disconnect(conn_id).await; - self.set_status(Status::Disconnected); + self.set_status(Status::Disconnected, cx); Ok(()) } @@ -444,13 +484,40 @@ const LOGIN_RESPONSE: &'static str = " "; -#[test] -fn test_encode_and_decode_worktree_url() { - let url = encode_worktree_url(5, "deadbeef"); - assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string()))); - assert_eq!( - decode_worktree_url(&format!("\n {}\t", url)), - Some((5, "deadbeef".to_string())) - ); - assert_eq!(decode_worktree_url("not://the-right-format"), None); +#[cfg(test)] +mod tests { + use super::*; + use crate::test::FakeServer; + use gpui::TestAppContext; + + #[gpui::test(iterations = 1000)] + async fn test_heartbeat(cx: TestAppContext) { + let user_id = 5; + let client = Client::new(); + + client.state.write().heartbeat_interval = Duration::from_millis(1); + let mut server = FakeServer::for_client(user_id, &client, &cx).await; + + let ping = server.receive::().await.unwrap(); + assert_eq!(ping.payload.id, 0); + server.respond(ping.receipt(), proto::Pong { id: 0 }).await; + + let ping = server.receive::().await.unwrap(); + assert_eq!(ping.payload.id, 1); + server.respond(ping.receipt(), proto::Pong { id: 1 }).await; + + client.disconnect(&cx.to_async()).await.unwrap(); + assert!(server.receive::().await.is_err()); + } + + #[test] + fn test_encode_and_decode_worktree_url() { + let url = encode_worktree_url(5, "deadbeef"); + assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string()))); + assert_eq!( + decode_worktree_url(&format!("\n {}\t", url)), + Some((5, "deadbeef".to_string())) + ); + assert_eq!(decode_worktree_url("not://the-right-format"), None); + } } diff --git a/zed/src/test.rs b/zed/src/test.rs index b917e428f683df23a6b6cd23f51e1c95be849d27..f34ff550149d2fe5c7da773b8c3aec07b869f15c 100644 --- a/zed/src/test.rs +++ b/zed/src/test.rs @@ -3,14 +3,16 @@ use crate::{ channel::ChannelList, fs::RealFs, language::LanguageRegistry, - rpc, + rpc::{self, Client}, settings::{self, ThemeRegistry}, time::ReplicaId, user::UserStore, AppState, }; -use gpui::{Entity, ModelHandle, MutableAppContext}; +use anyhow::{anyhow, Result}; +use gpui::{Entity, ModelHandle, MutableAppContext, TestAppContext}; use parking_lot::Mutex; +use postage::{mpsc, prelude::Stream as _}; use smol::channel; use std::{ marker::PhantomData, @@ -18,6 +20,7 @@ use std::{ sync::Arc, }; use tempdir::TempDir; +use zrpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; #[cfg(feature = "test-support")] pub use zrpc::test::Channel; @@ -195,3 +198,50 @@ impl Observer { (observer, notify_rx) } } + +pub struct FakeServer { + peer: Arc, + incoming: mpsc::Receiver>, + connection_id: ConnectionId, +} + +impl FakeServer { + pub async fn for_client(user_id: u64, client: &Arc, cx: &TestAppContext) -> Self { + let (client_conn, server_conn) = zrpc::test::Channel::bidirectional(); + let peer = Peer::new(); + let (connection_id, io, incoming) = peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + + client + .set_connection(user_id, client_conn, &cx.to_async()) + .await + .unwrap(); + + Self { + peer, + incoming, + connection_id, + } + } + + pub async fn send(&self, message: T) { + self.peer.send(self.connection_id, message).await.unwrap(); + } + + pub async fn receive(&mut self) -> Result> { + let message = self + .incoming + .recv() + .await + .ok_or_else(|| anyhow!("other half hung up"))?; + Ok(*message.into_any().downcast::>().unwrap()) + } + + pub async fn respond( + &self, + receipt: Receipt, + response: T::Response, + ) { + self.peer.respond(receipt, response).await.unwrap() + } +}