Detailed changes
@@ -3,8 +3,9 @@ use async_task::Runnable;
pub use async_task::Task;
use backtrace::{Backtrace, BacktraceFmt, BytesOrWideString};
use parking_lot::Mutex;
+use postage::{barrier, prelude::Stream as _};
use rand::prelude::*;
-use smol::{channel, prelude::*, Executor};
+use smol::{channel, prelude::*, Executor, Timer};
use std::{
fmt::{self, Debug},
marker::PhantomData,
@@ -18,7 +19,7 @@ use std::{
},
task::{Context, Poll},
thread,
- time::Duration,
+ time::{Duration, Instant},
};
use waker_fn::waker_fn;
@@ -49,6 +50,8 @@ struct DeterministicState {
spawned_from_foreground: Vec<(Runnable, Backtrace)>,
forbid_parking: bool,
block_on_ticks: RangeInclusive<usize>,
+ now: Instant,
+ pending_timers: Vec<(Instant, barrier::Sender)>,
}
pub struct Deterministic {
@@ -67,6 +70,8 @@ impl Deterministic {
spawned_from_foreground: Default::default(),
forbid_parking: false,
block_on_ticks: 0..=1000,
+ now: Instant::now(),
+ pending_timers: Default::default(),
})),
parker: Default::default(),
}
@@ -119,17 +124,39 @@ impl Deterministic {
T: 'static,
F: Future<Output = T> + 'static,
{
+ let woken = Arc::new(AtomicBool::new(false));
+ let mut future = Box::pin(future);
+ loop {
+ if let Some(result) = self.run_internal(woken.clone(), &mut future) {
+ return result;
+ }
+
+ if !woken.load(SeqCst) && self.state.lock().forbid_parking {
+ panic!("deterministic executor parked after a call to forbid_parking");
+ }
+
+ woken.store(false, SeqCst);
+ self.parker.lock().park();
+ }
+ }
+
+ fn run_until_parked(&self) {
+ let woken = Arc::new(AtomicBool::new(false));
+ let future = std::future::pending::<()>();
smol::pin!(future);
+ self.run_internal(woken, future);
+ }
+ pub fn run_internal<F, T>(&self, woken: Arc<AtomicBool>, mut future: F) -> Option<T>
+ where
+ T: 'static,
+ F: Future<Output = T> + Unpin,
+ {
let unparker = self.parker.lock().unparker();
- let woken = Arc::new(AtomicBool::new(false));
- let waker = {
- let woken = woken.clone();
- waker_fn(move || {
- woken.store(true, SeqCst);
- unparker.unpark();
- })
- };
+ let waker = waker_fn(move || {
+ woken.store(true, SeqCst);
+ unparker.unpark();
+ });
let mut cx = Context::from_waker(&waker);
let mut trace = Trace::default();
@@ -163,23 +190,17 @@ impl Deterministic {
runnable.run();
} else {
drop(state);
- if let Poll::Ready(result) = future.as_mut().poll(&mut cx) {
- return result;
+ if let Poll::Ready(result) = future.poll(&mut cx) {
+ return Some(result);
}
+
let state = self.state.lock();
if state.scheduled_from_foreground.is_empty()
&& state.scheduled_from_background.is_empty()
&& state.spawned_from_foreground.is_empty()
{
- if state.forbid_parking && !woken.load(SeqCst) {
- panic!("deterministic executor parked after a call to forbid_parking");
- }
- drop(state);
- woken.store(false, SeqCst);
- self.parker.lock().park();
+ return None;
}
-
- continue;
}
}
}
@@ -407,6 +428,41 @@ impl Foreground {
}
}
+ pub async fn timer(&self, duration: Duration) {
+ match self {
+ Self::Deterministic(executor) => {
+ let (tx, mut rx) = barrier::channel();
+ {
+ let mut state = executor.state.lock();
+ let wakeup_at = state.now + duration;
+ state.pending_timers.push((wakeup_at, tx));
+ }
+ rx.recv().await;
+ }
+ _ => {
+ Timer::after(duration).await;
+ }
+ }
+ }
+
+ pub fn advance_clock(&self, duration: Duration) {
+ match self {
+ Self::Deterministic(executor) => {
+ executor.run_until_parked();
+
+ let mut state = executor.state.lock();
+ state.now += duration;
+ let now = state.now;
+ let mut pending_timers = mem::take(&mut state.pending_timers);
+ drop(state);
+
+ pending_timers.retain(|(wakeup, _)| *wakeup > now);
+ executor.state.lock().pending_timers.extend(pending_timers);
+ }
+ _ => panic!("this method can only be called on a deterministic executor"),
+ }
+ }
+
pub fn set_block_on_ticks(&self, range: RangeInclusive<usize>) {
match self {
Self::Deterministic(executor) => executor.state.lock().block_on_ticks = range,
@@ -5,10 +5,7 @@ use super::{
};
use anyhow::anyhow;
use async_std::{sync::RwLock, task};
-use async_tungstenite::{
- tungstenite::{protocol::Role, Error as WebSocketError, Message as WebSocketMessage},
- WebSocketStream,
-};
+use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use futures::{future::BoxFuture, FutureExt};
use postage::{mpsc, prelude::Sink as _, prelude::Stream as _};
use sha1::{Digest as _, Sha1};
@@ -30,7 +27,7 @@ use time::OffsetDateTime;
use zrpc::{
auth::random_token,
proto::{self, AnyTypedEnvelope, EnvelopedMessage},
- ConnectionId, Peer, TypedEnvelope,
+ Conn, ConnectionId, Peer, TypedEnvelope,
};
type ReplicaId = u16;
@@ -95,6 +92,7 @@ impl Server {
};
server
+ .add_handler(Server::ping)
.add_handler(Server::share_worktree)
.add_handler(Server::join_worktree)
.add_handler(Server::update_worktree)
@@ -133,19 +131,12 @@ impl Server {
self
}
- pub fn handle_connection<Conn>(
+ pub fn handle_connection(
self: &Arc<Self>,
connection: Conn,
addr: String,
user_id: UserId,
- ) -> impl Future<Output = ()>
- where
- Conn: 'static
- + futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Send
- + Unpin,
- {
+ ) -> impl Future<Output = ()> {
let this = self.clone();
async move {
let (connection_id, handle_io, mut incoming_rx) =
@@ -254,6 +245,11 @@ impl Server {
worktree_ids
}
+ async fn ping(self: Arc<Server>, request: TypedEnvelope<proto::Ping>) -> tide::Result<()> {
+ self.peer.respond(request.receipt(), proto::Ack {}).await?;
+ Ok(())
+ }
+
async fn share_worktree(
self: Arc<Server>,
mut request: TypedEnvelope<proto::ShareWorktree>,
@@ -503,7 +499,9 @@ impl Server {
request: TypedEnvelope<proto::UpdateBuffer>,
) -> tide::Result<()> {
self.broadcast_in_worktree(request.payload.worktree_id, &request)
- .await
+ .await?;
+ self.peer.respond(request.receipt(), proto::Ack {}).await?;
+ Ok(())
}
async fn buffer_saved(
@@ -974,8 +972,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
let user_id = user_id.ok_or_else(|| anyhow!("user_id is not present on request. ensure auth::VerifyToken middleware is present"))?;
task::spawn(async move {
if let Some(stream) = upgrade_receiver.await {
- let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
- server.handle_connection(stream, addr, user_id).await;
+ server.handle_connection(Conn::new(WebSocketStream::from_raw_socket(stream, Role::Server, None).await), addr, user_id).await;
}
});
@@ -1009,17 +1006,25 @@ mod tests {
};
use async_std::{sync::RwLockReadGuard, task};
use gpui::TestAppContext;
- use postage::mpsc;
+ use parking_lot::Mutex;
+ use postage::{mpsc, watch};
use serde_json::json;
use sqlx::types::time::OffsetDateTime;
- use std::{path::Path, sync::Arc, time::Duration};
+ use std::{
+ path::Path,
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ },
+ time::Duration,
+ };
use zed::{
channel::{Channel, ChannelDetails, ChannelList},
editor::{Editor, Insert},
fs::{FakeFs, Fs as _},
language::LanguageRegistry,
- rpc::Client,
- settings, test,
+ rpc::{self, Client},
+ settings,
user::UserStore,
worktree::Worktree,
};
@@ -1469,7 +1474,7 @@ mod tests {
.await;
// Drop client B's connection and ensure client A observes client B leaving the worktree.
- client_b.disconnect().await.unwrap();
+ client_b.disconnect(&cx_b.to_async()).await.unwrap();
worktree_a
.condition(&cx_a, |tree, _| tree.peers().len() == 0)
.await;
@@ -1675,11 +1680,206 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_chat_reconnection(mut cx_a: TestAppContext, mut cx_b: TestAppContext) {
+ cx_a.foreground().forbid_parking();
+
+ // Connect to a server as 2 clients.
+ let mut server = TestServer::start().await;
+ let (user_id_a, client_a) = server.create_client(&mut cx_a, "user_a").await;
+ let (user_id_b, client_b) = server.create_client(&mut cx_b, "user_b").await;
+ let mut status_b = client_b.status();
+
+ // Create an org that includes these 2 users.
+ let db = &server.app_state.db;
+ let org_id = db.create_org("Test Org", "test-org").await.unwrap();
+ db.add_org_member(org_id, user_id_a, false).await.unwrap();
+ db.add_org_member(org_id, user_id_b, false).await.unwrap();
+
+ // Create a channel that includes all the users.
+ let channel_id = db.create_org_channel(org_id, "test-channel").await.unwrap();
+ db.add_channel_member(channel_id, user_id_a, false)
+ .await
+ .unwrap();
+ db.add_channel_member(channel_id, user_id_b, false)
+ .await
+ .unwrap();
+ db.create_channel_message(
+ channel_id,
+ user_id_b,
+ "hello A, it's B.",
+ OffsetDateTime::now_utc(),
+ )
+ .await
+ .unwrap();
+
+ let user_store_a = Arc::new(UserStore::new(client_a.clone()));
+ let channels_a = cx_a.add_model(|cx| ChannelList::new(user_store_a, client_a, cx));
+ channels_a
+ .condition(&mut cx_a, |list, _| list.available_channels().is_some())
+ .await;
+
+ channels_a.read_with(&cx_a, |list, _| {
+ assert_eq!(
+ list.available_channels().unwrap(),
+ &[ChannelDetails {
+ id: channel_id.to_proto(),
+ name: "test-channel".to_string()
+ }]
+ )
+ });
+ let channel_a = channels_a.update(&mut cx_a, |this, cx| {
+ this.get_channel(channel_id.to_proto(), cx).unwrap()
+ });
+ channel_a.read_with(&cx_a, |channel, _| assert!(channel.messages().is_empty()));
+ channel_a
+ .condition(&cx_a, |channel, _| {
+ channel_messages(channel)
+ == [("user_b".to_string(), "hello A, it's B.".to_string())]
+ })
+ .await;
+
+ let user_store_b = Arc::new(UserStore::new(client_b.clone()));
+ let channels_b = cx_b.add_model(|cx| ChannelList::new(user_store_b, client_b, cx));
+ channels_b
+ .condition(&mut cx_b, |list, _| list.available_channels().is_some())
+ .await;
+ channels_b.read_with(&cx_b, |list, _| {
+ assert_eq!(
+ list.available_channels().unwrap(),
+ &[ChannelDetails {
+ id: channel_id.to_proto(),
+ name: "test-channel".to_string()
+ }]
+ )
+ });
+
+ let channel_b = channels_b.update(&mut cx_b, |this, cx| {
+ this.get_channel(channel_id.to_proto(), cx).unwrap()
+ });
+ channel_b.read_with(&cx_b, |channel, _| assert!(channel.messages().is_empty()));
+ channel_b
+ .condition(&cx_b, |channel, _| {
+ channel_messages(channel)
+ == [("user_b".to_string(), "hello A, it's B.".to_string())]
+ })
+ .await;
+
+ // Disconnect client B, ensuring we can still access its cached channel data.
+ server.forbid_connections();
+ server.disconnect_client(user_id_b);
+ while !matches!(
+ status_b.recv().await,
+ Some(rpc::Status::ReconnectionError { .. })
+ ) {}
+
+ channels_b.read_with(&cx_b, |channels, _| {
+ assert_eq!(
+ channels.available_channels().unwrap(),
+ [ChannelDetails {
+ id: channel_id.to_proto(),
+ name: "test-channel".to_string()
+ }]
+ )
+ });
+ channel_b.read_with(&cx_b, |channel, _| {
+ assert_eq!(
+ channel_messages(channel),
+ [("user_b".to_string(), "hello A, it's B.".to_string())]
+ )
+ });
+
+ // Send a message from client A while B is disconnected.
+ channel_a
+ .update(&mut cx_a, |channel, cx| {
+ channel
+ .send_message("oh, hi B.".to_string(), cx)
+ .unwrap()
+ .detach();
+ let task = channel.send_message("sup".to_string(), cx).unwrap();
+ assert_eq!(
+ channel
+ .pending_messages()
+ .iter()
+ .map(|m| &m.body)
+ .collect::<Vec<_>>(),
+ &["oh, hi B.", "sup"]
+ );
+ task
+ })
+ .await
+ .unwrap();
+
+ // Give client B a chance to reconnect.
+ server.allow_connections();
+ cx_b.foreground().advance_clock(Duration::from_secs(10));
+
+ // Verify that B sees the new messages upon reconnection.
+ channel_b
+ .condition(&cx_b, |channel, _| {
+ channel_messages(channel)
+ == [
+ ("user_b".to_string(), "hello A, it's B.".to_string()),
+ ("user_a".to_string(), "oh, hi B.".to_string()),
+ ("user_a".to_string(), "sup".to_string()),
+ ]
+ })
+ .await;
+
+ // Ensure client A and B can communicate normally after reconnection.
+ channel_a
+ .update(&mut cx_a, |channel, cx| {
+ channel.send_message("you online?".to_string(), cx).unwrap()
+ })
+ .await
+ .unwrap();
+ channel_b
+ .condition(&cx_b, |channel, _| {
+ channel_messages(channel)
+ == [
+ ("user_b".to_string(), "hello A, it's B.".to_string()),
+ ("user_a".to_string(), "oh, hi B.".to_string()),
+ ("user_a".to_string(), "sup".to_string()),
+ ("user_a".to_string(), "you online?".to_string()),
+ ]
+ })
+ .await;
+
+ channel_b
+ .update(&mut cx_b, |channel, cx| {
+ channel.send_message("yep".to_string(), cx).unwrap()
+ })
+ .await
+ .unwrap();
+ channel_a
+ .condition(&cx_a, |channel, _| {
+ channel_messages(channel)
+ == [
+ ("user_b".to_string(), "hello A, it's B.".to_string()),
+ ("user_a".to_string(), "oh, hi B.".to_string()),
+ ("user_a".to_string(), "sup".to_string()),
+ ("user_a".to_string(), "you online?".to_string()),
+ ("user_b".to_string(), "yep".to_string()),
+ ]
+ })
+ .await;
+
+ fn channel_messages(channel: &Channel) -> Vec<(String, String)> {
+ channel
+ .messages()
+ .cursor::<(), ()>()
+ .map(|m| (m.sender.github_login.clone(), m.body.clone()))
+ .collect()
+ }
+ }
+
struct TestServer {
peer: Arc<Peer>,
app_state: Arc<AppState>,
server: Arc<Server>,
notifications: mpsc::Receiver<()>,
+ connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
+ forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
@@ -1695,6 +1895,8 @@ mod tests {
app_state,
server,
notifications: notifications.1,
+ connection_killers: Default::default(),
+ forbid_connections: Default::default(),
_test_db: test_db,
}
}
@@ -1704,20 +1906,67 @@ mod tests {
cx: &mut TestAppContext,
name: &str,
) -> (UserId, Arc<Client>) {
- let user_id = self.app_state.db.create_user(name, false).await.unwrap();
- let client = Client::new();
- let (client_conn, server_conn) = test::Channel::bidirectional();
- cx.background()
- .spawn(
- self.server
- .handle_connection(server_conn, name.to_string(), user_id),
- )
- .detach();
+ let client_user_id = self.app_state.db.create_user(name, false).await.unwrap();
+ let client_name = name.to_string();
+ let mut client = Client::new();
+ let server = self.server.clone();
+ let connection_killers = self.connection_killers.clone();
+ let forbid_connections = self.forbid_connections.clone();
+ Arc::get_mut(&mut client)
+ .unwrap()
+ .set_login_and_connect_callbacks(
+ move |cx| {
+ cx.spawn(|_| async move {
+ let access_token = "the-token".to_string();
+ Ok((client_user_id.0 as u64, access_token))
+ })
+ },
+ move |user_id, access_token, cx| {
+ assert_eq!(user_id, client_user_id.0 as u64);
+ assert_eq!(access_token, "the-token");
+
+ let server = server.clone();
+ let connection_killers = connection_killers.clone();
+ let forbid_connections = forbid_connections.clone();
+ let client_name = client_name.clone();
+ cx.spawn(move |cx| async move {
+ if forbid_connections.load(SeqCst) {
+ Err(anyhow!("server is forbidding connections"))
+ } else {
+ let (client_conn, server_conn, kill_conn) = Conn::in_memory();
+ connection_killers.lock().insert(client_user_id, kill_conn);
+ cx.background()
+ .spawn(server.handle_connection(
+ server_conn,
+ client_name,
+ client_user_id,
+ ))
+ .detach();
+ Ok(client_conn)
+ }
+ })
+ },
+ );
+
client
- .add_connection(user_id.to_proto(), client_conn, &cx.to_async())
+ .authenticate_and_connect(&cx.to_async())
.await
.unwrap();
- (user_id, client)
+ (client_user_id, client)
+ }
+
+ fn disconnect_client(&self, user_id: UserId) {
+ if let Some(mut kill_conn) = self.connection_killers.lock().remove(&user_id) {
+ let _ = kill_conn.try_send(Some(()));
+ }
+ }
+
+ fn forbid_connections(&self) {
+ self.forbid_connections.store(true, SeqCst);
+ }
+
+ fn allow_connections(&self) {
+ self.forbid_connections.store(false, SeqCst);
}
async fn build_app_state(test_db: &TestDb) -> Arc<AppState> {
@@ -11,6 +11,7 @@ use gpui::{
use postage::prelude::Stream;
use std::{
collections::{HashMap, HashSet},
+ mem,
ops::Range,
sync::Arc,
};
@@ -71,7 +72,7 @@ pub enum ChannelListEvent {}
#[derive(Clone, Debug, PartialEq)]
pub enum ChannelEvent {
- MessagesAdded {
+ MessagesUpdated {
old_range: Range<usize>,
new_count: usize,
},
@@ -87,36 +88,47 @@ impl ChannelList {
rpc: Arc<rpc::Client>,
cx: &mut ModelContext<Self>,
) -> Self {
- let _task = cx.spawn(|this, mut cx| {
+ let _task = cx.spawn_weak(|this, mut cx| {
let rpc = rpc.clone();
async move {
- let mut user_id = rpc.user_id();
- loop {
- let available_channels = if user_id.recv().await.unwrap().is_some() {
- Some(
- rpc.request(proto::GetChannels {})
+ let mut status = rpc.status();
+ while let Some((status, this)) = status.recv().await.zip(this.upgrade(&cx)) {
+ match status {
+ rpc::Status::Connected { .. } => {
+ let response = rpc
+ .request(proto::GetChannels {})
.await
- .context("failed to fetch available channels")?
- .channels
- .into_iter()
- .map(Into::into)
- .collect(),
- )
- } else {
- None
- };
-
- this.update(&mut cx, |this, cx| {
- if available_channels.is_none() {
- if this.available_channels.is_none() {
- return;
- }
- this.channels.clear();
+ .context("failed to fetch available channels")?;
+ this.update(&mut cx, |this, cx| {
+ this.available_channels =
+ Some(response.channels.into_iter().map(Into::into).collect());
+
+ let mut to_remove = Vec::new();
+ for (channel_id, channel) in &this.channels {
+ if let Some(channel) = channel.upgrade(cx) {
+ channel.update(cx, |channel, cx| channel.rejoin(cx))
+ } else {
+ to_remove.push(*channel_id);
+ }
+ }
+
+ for channel_id in to_remove {
+ this.channels.remove(&channel_id);
+ }
+ cx.notify();
+ });
}
- this.available_channels = available_channels;
- cx.notify();
- });
+ rpc::Status::Disconnected { .. } => {
+ this.update(&mut cx, |this, cx| {
+ this.available_channels = None;
+ this.channels.clear();
+ cx.notify();
+ });
+ }
+ _ => {}
+ }
}
+ Ok(())
}
.log_err()
});
@@ -285,6 +297,43 @@ impl Channel {
false
}
+ pub fn rejoin(&mut self, cx: &mut ModelContext<Self>) {
+ let user_store = self.user_store.clone();
+ let rpc = self.rpc.clone();
+ let channel_id = self.details.id;
+ cx.spawn(|channel, mut cx| {
+ async move {
+ let response = rpc.request(proto::JoinChannel { channel_id }).await?;
+ let messages = messages_from_proto(response.messages, &user_store).await?;
+ let loaded_all_messages = response.done;
+
+ channel.update(&mut cx, |channel, cx| {
+ if let Some((first_new_message, last_old_message)) =
+ messages.first().zip(channel.messages.last())
+ {
+ if first_new_message.id > last_old_message.id {
+ let old_messages = mem::take(&mut channel.messages);
+ cx.emit(ChannelEvent::MessagesUpdated {
+ old_range: 0..old_messages.summary().count,
+ new_count: 0,
+ });
+ channel.loaded_all_messages = loaded_all_messages;
+ }
+ }
+
+ channel.insert_messages(messages, cx);
+ if loaded_all_messages {
+ channel.loaded_all_messages = loaded_all_messages;
+ }
+ });
+
+ Ok(())
+ }
+ .log_err()
+ })
+ .detach();
+ }
+
pub fn message_count(&self) -> usize {
self.messages.summary().count
}
@@ -350,7 +399,7 @@ impl Channel {
drop(old_cursor);
self.messages = new_messages;
- cx.emit(ChannelEvent::MessagesAdded {
+ cx.emit(ChannelEvent::MessagesUpdated {
old_range: start_ix..end_ix,
new_count,
});
@@ -446,22 +495,21 @@ impl<'a> sum_tree::SeekDimension<'a, ChannelMessageSummary> for Count {
#[cfg(test)]
mod tests {
use super::*;
+ use crate::test::FakeServer;
use gpui::TestAppContext;
- use postage::mpsc::Receiver;
- use zrpc::{test::Channel, ConnectionId, Peer, Receipt};
#[gpui::test]
async fn test_channel_messages(mut cx: TestAppContext) {
let user_id = 5;
- let client = Client::new();
- let mut server = FakeServer::for_client(user_id, &client, &cx).await;
+ let mut client = Client::new();
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
let user_store = Arc::new(UserStore::new(client.clone()));
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
channel_list.read_with(&cx, |list, _| assert_eq!(list.available_channels(), None));
// Get the available channels.
- let get_channels = server.receive::<proto::GetChannels>().await;
+ let get_channels = server.receive::<proto::GetChannels>().await.unwrap();
server
.respond(
get_channels.receipt(),
@@ -492,7 +540,7 @@ mod tests {
})
.unwrap();
channel.read_with(&cx, |channel, _| assert!(channel.messages().is_empty()));
- let join_channel = server.receive::<proto::JoinChannel>().await;
+ let join_channel = server.receive::<proto::JoinChannel>().await.unwrap();
server
.respond(
join_channel.receipt(),
@@ -517,7 +565,7 @@ mod tests {
.await;
// Client requests all users for the received messages
- let mut get_users = server.receive::<proto::GetUsers>().await;
+ let mut get_users = server.receive::<proto::GetUsers>().await.unwrap();
get_users.payload.user_ids.sort();
assert_eq!(get_users.payload.user_ids, vec![5, 6]);
server
@@ -542,7 +590,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
- ChannelEvent::MessagesAdded {
+ ChannelEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
@@ -574,7 +622,7 @@ mod tests {
.await;
// Client requests user for message since they haven't seen them yet
- let get_users = server.receive::<proto::GetUsers>().await;
+ let get_users = server.receive::<proto::GetUsers>().await.unwrap();
assert_eq!(get_users.payload.user_ids, vec![7]);
server
.respond(
@@ -591,7 +639,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
- ChannelEvent::MessagesAdded {
+ ChannelEvent::MessagesUpdated {
old_range: 2..2,
new_count: 1,
}
@@ -610,7 +658,7 @@ mod tests {
channel.update(&mut cx, |channel, cx| {
assert!(channel.load_more_messages(cx));
});
- let get_messages = server.receive::<proto::GetChannelMessages>().await;
+ let get_messages = server.receive::<proto::GetChannelMessages>().await.unwrap();
assert_eq!(get_messages.payload.channel_id, 5);
assert_eq!(get_messages.payload.before_message_id, 10);
server
@@ -638,7 +686,7 @@ mod tests {
assert_eq!(
channel.next_event(&cx).await,
- ChannelEvent::MessagesAdded {
+ ChannelEvent::MessagesUpdated {
old_range: 0..0,
new_count: 2,
}
@@ -656,53 +704,4 @@ mod tests {
);
});
}
-
- struct FakeServer {
- peer: Arc<Peer>,
- incoming: Receiver<Box<dyn proto::AnyTypedEnvelope>>,
- connection_id: ConnectionId,
- }
-
- impl FakeServer {
- async fn for_client(user_id: u64, client: &Arc<Client>, cx: &TestAppContext) -> Self {
- let (client_conn, server_conn) = Channel::bidirectional();
- let peer = Peer::new();
- let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
- cx.background().spawn(io).detach();
-
- client
- .add_connection(user_id, client_conn, &cx.to_async())
- .await
- .unwrap();
-
- Self {
- peer,
- incoming,
- connection_id,
- }
- }
-
- async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
- self.peer.send(self.connection_id, message).await.unwrap();
- }
-
- async fn receive<M: proto::EnvelopedMessage>(&mut self) -> TypedEnvelope<M> {
- *self
- .incoming
- .recv()
- .await
- .unwrap()
- .into_any()
- .downcast::<TypedEnvelope<M>>()
- .unwrap()
- }
-
- async fn respond<T: proto::RequestMessage>(
- &self,
- receipt: Receipt<T>,
- response: T::Response,
- ) {
- self.peer.respond(receipt, response).await.unwrap()
- }
- }
}
@@ -3,7 +3,7 @@ use std::sync::Arc;
use crate::{
channel::{Channel, ChannelEvent, ChannelList, ChannelMessage},
editor::Editor,
- rpc::Client,
+ rpc::{self, Client},
theme,
util::{ResultExt, TryFutureExt},
Settings,
@@ -14,10 +14,10 @@ use gpui::{
keymap::Binding,
platform::CursorStyle,
views::{ItemType, Select, SelectStyle},
- AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, View,
+ AppContext, Entity, ModelHandle, MutableAppContext, RenderContext, Subscription, Task, View,
ViewContext, ViewHandle,
};
-use postage::watch;
+use postage::{prelude::Stream, watch};
use time::{OffsetDateTime, UtcOffset};
const MESSAGE_LOADING_THRESHOLD: usize = 50;
@@ -31,6 +31,7 @@ pub struct ChatPanel {
channel_select: ViewHandle<Select>,
settings: watch::Receiver<Settings>,
local_timezone: UtcOffset,
+ _observe_status: Task<()>,
}
pub enum Event {}
@@ -98,6 +99,14 @@ impl ChatPanel {
cx.dispatch_action(LoadMoreMessages);
}
});
+ let _observe_status = cx.spawn(|this, mut cx| {
+ let mut status = rpc.status();
+ async move {
+ while let Some(_) = status.recv().await {
+ this.update(&mut cx, |_, cx| cx.notify());
+ }
+ }
+ });
let mut this = Self {
rpc,
@@ -108,6 +117,7 @@ impl ChatPanel {
channel_select,
settings,
local_timezone: cx.platform().local_timezone(),
+ _observe_status,
};
this.init_active_channel(cx);
@@ -153,6 +163,7 @@ impl ChatPanel {
if let Some(active_channel) = active_channel {
self.set_active_channel(active_channel, cx);
} else {
+ self.message_list.reset(0);
self.active_channel = None;
}
@@ -183,7 +194,7 @@ impl ChatPanel {
cx: &mut ViewContext<Self>,
) {
match event {
- ChannelEvent::MessagesAdded {
+ ChannelEvent::MessagesUpdated {
old_range,
new_count,
} => {
@@ -357,10 +368,6 @@ impl ChatPanel {
})
}
}
-
- fn is_signed_in(&self) -> bool {
- self.rpc.user_id().borrow().is_some()
- }
}
impl Entity for ChatPanel {
@@ -374,10 +381,9 @@ impl View for ChatPanel {
fn render(&mut self, cx: &mut RenderContext<Self>) -> ElementBox {
let theme = &self.settings.borrow().theme;
- let element = if self.is_signed_in() {
- self.render_channel()
- } else {
- self.render_sign_in_prompt(cx)
+ let element = match *self.rpc.status().borrow() {
+ rpc::Status::Connected { .. } => self.render_channel(),
+ _ => self.render_sign_in_prompt(cx),
};
ConstrainedBox::new(
Container::new(element)
@@ -389,7 +395,7 @@ impl View for ChatPanel {
}
fn on_focus(&mut self, cx: &mut ViewContext<Self>) {
- if self.is_signed_in() {
+ if matches!(*self.rpc.status().borrow(), rpc::Status::Connected { .. }) {
cx.focus(&self.input_editor);
}
}
@@ -2695,14 +2695,7 @@ impl<'a> Into<proto::operation::Edit> for &'a EditOperation {
impl<'a> Into<proto::Anchor> for &'a Anchor {
fn into(self) -> proto::Anchor {
proto::Anchor {
- version: self
- .version
- .iter()
- .map(|entry| proto::VectorClockEntry {
- replica_id: entry.replica_id as u32,
- timestamp: entry.value,
- })
- .collect(),
+ version: (&self.version).into(),
offset: self.offset as u64,
bias: match self.bias {
Bias::Left => proto::anchor::Bias::Left as i32,
@@ -1,24 +1,24 @@
use crate::util::ResultExt;
use anyhow::{anyhow, Context, Result};
use async_tungstenite::tungstenite::http::Request;
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
use gpui::{AsyncAppContext, Entity, ModelContext, Task};
use lazy_static::lazy_static;
use parking_lot::RwLock;
-use postage::prelude::Stream;
-use postage::sink::Sink;
-use postage::watch;
-use std::any::TypeId;
-use std::collections::HashMap;
-use std::sync::Weak;
-use std::time::{Duration, Instant};
-use std::{convert::TryFrom, future::Future, sync::Arc};
+use postage::{prelude::Stream, watch};
+use rand::prelude::*;
+use std::{
+ any::TypeId,
+ collections::HashMap,
+ convert::TryFrom,
+ future::Future,
+ sync::{Arc, Weak},
+ time::{Duration, Instant},
+};
use surf::Url;
-use zrpc::proto::{AnyTypedEnvelope, EntityMessage};
pub use zrpc::{proto, ConnectionId, PeerId, TypedEnvelope};
use zrpc::{
- proto::{EnvelopedMessage, RequestMessage},
- Peer, Receipt,
+ proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
+ Conn, Peer, Receipt,
};
lazy_static! {
@@ -29,25 +29,55 @@ lazy_static! {
pub struct Client {
peer: Arc<Peer>,
state: RwLock<ClientState>,
+ auth_callback: Option<
+ Box<dyn 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>>,
+ >,
+ connect_callback: Option<
+ Box<dyn 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>>,
+ >,
+}
+
+#[derive(Copy, Clone, Debug)]
+pub enum Status {
+ Disconnected,
+ Authenticating,
+ Connecting {
+ user_id: u64,
+ },
+ ConnectionError,
+ Connected {
+ connection_id: ConnectionId,
+ user_id: u64,
+ },
+ ConnectionLost,
+ Reauthenticating,
+ Reconnecting {
+ user_id: u64,
+ },
+ ReconnectionError {
+ next_reconnection: Instant,
+ },
}
struct ClientState {
- connection_id: Option<ConnectionId>,
- user_id: (watch::Sender<Option<u64>>, watch::Receiver<Option<u64>>),
+ status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
model_handlers: HashMap<
(TypeId, u64),
Box<dyn Send + Sync + FnMut(Box<dyn AnyTypedEnvelope>, &mut AsyncAppContext)>,
>,
+ _maintain_connection: Option<Task<()>>,
+ heartbeat_interval: Duration,
}
impl Default for ClientState {
fn default() -> Self {
Self {
- connection_id: Default::default(),
- user_id: watch::channel(),
+ status: watch::channel_with(Status::Disconnected),
entity_id_extractors: Default::default(),
model_handlers: Default::default(),
+ _maintain_connection: None,
+ heartbeat_interval: Duration::from_secs(5),
}
}
}
@@ -77,11 +107,71 @@ impl Client {
Arc::new(Self {
peer: Peer::new(),
state: Default::default(),
+ auth_callback: None,
+ connect_callback: None,
})
}
- pub fn user_id(&self) -> watch::Receiver<Option<u64>> {
- self.state.read().user_id.1.clone()
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn set_login_and_connect_callbacks<Login, Connect>(
+ &mut self,
+ login: Login,
+ connect: Connect,
+ ) where
+ Login: 'static + Send + Sync + Fn(&AsyncAppContext) -> Task<Result<(u64, String)>>,
+ Connect: 'static + Send + Sync + Fn(u64, &str, &AsyncAppContext) -> Task<Result<Conn>>,
+ {
+ self.auth_callback = Some(Box::new(login));
+ self.connect_callback = Some(Box::new(connect));
+ }
+
+ pub fn status(&self) -> watch::Receiver<Status> {
+ self.state.read().status.1.clone()
+ }
+
+ fn set_status(self: &Arc<Self>, status: Status, cx: &AsyncAppContext) {
+ let mut state = self.state.write();
+ *state.status.0.borrow_mut() = status;
+
+ match status {
+ Status::Connected { .. } => {
+ let heartbeat_interval = state.heartbeat_interval;
+ let this = self.clone();
+ let foreground = cx.foreground();
+ state._maintain_connection = Some(cx.foreground().spawn(async move {
+ loop {
+ foreground.timer(heartbeat_interval).await;
+ this.request(proto::Ping {}).await.unwrap();
+ }
+ }));
+ }
+ Status::ConnectionLost => {
+ let this = self.clone();
+ let foreground = cx.foreground();
+ let heartbeat_interval = state.heartbeat_interval;
+ state._maintain_connection = Some(cx.spawn(|cx| async move {
+ let mut rng = StdRng::from_entropy();
+ let mut delay = Duration::from_millis(100);
+ while let Err(error) = this.authenticate_and_connect(&cx).await {
+ log::error!("failed to connect {}", error);
+ this.set_status(
+ Status::ReconnectionError {
+ next_reconnection: Instant::now() + delay,
+ },
+ &cx,
+ );
+ foreground.timer(delay).await;
+ delay = delay
+ .mul_f32(rng.gen_range(1.0..=2.0))
+ .min(heartbeat_interval);
+ }
+ }));
+ }
+ Status::Disconnected => {
+ state._maintain_connection.take();
+ }
+ _ => {}
+ }
}
pub fn subscribe_from_model<T, M, F>(
@@ -141,56 +231,57 @@ impl Client {
self: &Arc<Self>,
cx: &AsyncAppContext,
) -> anyhow::Result<()> {
- if self.state.read().connection_id.is_some() {
- return Ok(());
- }
-
- let (user_id, access_token) = Self::login(cx.platform(), &cx.background()).await?;
- let user_id = user_id.parse::<u64>()?;
- let request =
- Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+ let was_disconnected = match *self.status().borrow() {
+ Status::Disconnected => true,
+ Status::ConnectionError | Status::ConnectionLost | Status::ReconnectionError { .. } => {
+ false
+ }
+ Status::Connected { .. }
+ | Status::Connecting { .. }
+ | Status::Reconnecting { .. }
+ | Status::Authenticating
+ | Status::Reauthenticating => return Ok(()),
+ };
- if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
- let stream = smol::net::TcpStream::connect(host).await?;
- let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
- let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
- .await
- .context("websocket handshake")?;
- self.add_connection(user_id, stream, cx).await?;
- } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
- let stream = smol::net::TcpStream::connect(host).await?;
- let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
- let (stream, _) = async_tungstenite::client_async(request, stream)
- .await
- .context("websocket handshake")?;
- self.add_connection(user_id, stream, cx).await?;
+ if was_disconnected {
+ self.set_status(Status::Authenticating, cx);
} else {
- return Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))?;
+ self.set_status(Status::Reauthenticating, cx)
+ }
+
+ let (user_id, access_token) = match self.authenticate(&cx).await {
+ Ok(result) => result,
+ Err(err) => {
+ self.set_status(Status::ConnectionError, cx);
+ return Err(err);
+ }
};
- log::info!("connected to rpc address {}", *ZED_SERVER_URL);
- Ok(())
+ if was_disconnected {
+ self.set_status(Status::Connecting { user_id }, cx);
+ } else {
+ self.set_status(Status::Reconnecting { user_id }, cx);
+ }
+ match self.connect(user_id, &access_token, cx).await {
+ Ok(conn) => {
+ log::info!("connected to rpc address {}", *ZED_SERVER_URL);
+ self.set_connection(user_id, conn, cx).await;
+ Ok(())
+ }
+ Err(err) => {
+ self.set_status(Status::ConnectionError, cx);
+ Err(err)
+ }
+ }
}
- pub async fn add_connection<Conn>(
- self: &Arc<Self>,
- user_id: u64,
- conn: Conn,
- cx: &AsyncAppContext,
- ) -> anyhow::Result<()>
- where
- Conn: 'static
- + futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Unpin
- + Send,
- {
+ async fn set_connection(self: &Arc<Self>, user_id: u64, conn: Conn, cx: &AsyncAppContext) {
let (connection_id, handle_io, mut incoming) = self.peer.add_connection(conn).await;
- {
- let mut cx = cx.clone();
- let this = self.clone();
- cx.foreground()
- .spawn(async move {
+ cx.foreground()
+ .spawn({
+ let mut cx = cx.clone();
+ let this = self.clone();
+ async move {
while let Some(message) = incoming.recv().await {
let mut state = this.state.write();
if let Some(extract_entity_id) =
@@ -215,27 +306,90 @@ impl Client {
log::info!("unhandled message {}", message.payload_type_name());
}
}
- })
- .detach();
- }
- cx.background()
+ }
+ })
+ .detach();
+
+ self.set_status(
+ Status::Connected {
+ connection_id,
+ user_id,
+ },
+ cx,
+ );
+
+ let handle_io = cx.background().spawn(handle_io);
+ let this = self.clone();
+ let cx = cx.clone();
+ cx.foreground()
.spawn(async move {
- if let Err(error) = handle_io.await {
- log::error!("connection error: {:?}", error);
+ match handle_io.await {
+ Ok(()) => this.set_status(Status::Disconnected, &cx),
+ Err(err) => {
+ log::error!("connection error: {:?}", err);
+ this.set_status(Status::ConnectionLost, &cx);
+ }
}
})
.detach();
- let mut state = self.state.write();
- state.connection_id = Some(connection_id);
- state.user_id.0.send(Some(user_id)).await?;
- Ok(())
}
- pub fn login(
- platform: Arc<dyn gpui::Platform>,
- executor: &Arc<gpui::executor::Background>,
- ) -> Task<Result<(String, String)>> {
- let executor = executor.clone();
+ fn authenticate(self: &Arc<Self>, cx: &AsyncAppContext) -> Task<Result<(u64, String)>> {
+ if let Some(callback) = self.auth_callback.as_ref() {
+ callback(cx)
+ } else {
+ self.authenticate_with_browser(cx)
+ }
+ }
+
+ fn connect(
+ self: &Arc<Self>,
+ user_id: u64,
+ access_token: &str,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<Conn>> {
+ if let Some(callback) = self.connect_callback.as_ref() {
+ callback(user_id, access_token, cx)
+ } else {
+ self.connect_with_websocket(user_id, access_token, cx)
+ }
+ }
+
+ fn connect_with_websocket(
+ self: &Arc<Self>,
+ user_id: u64,
+ access_token: &str,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<Conn>> {
+ let request =
+ Request::builder().header("Authorization", format!("{} {}", user_id, access_token));
+ cx.background().spawn(async move {
+ if let Some(host) = ZED_SERVER_URL.strip_prefix("https://") {
+ let stream = smol::net::TcpStream::connect(host).await?;
+ let request = request.uri(format!("wss://{}/rpc", host)).body(())?;
+ let (stream, _) = async_tungstenite::async_tls::client_async_tls(request, stream)
+ .await
+ .context("websocket handshake")?;
+ Ok(Conn::new(stream))
+ } else if let Some(host) = ZED_SERVER_URL.strip_prefix("http://") {
+ let stream = smol::net::TcpStream::connect(host).await?;
+ let request = request.uri(format!("ws://{}/rpc", host)).body(())?;
+ let (stream, _) = async_tungstenite::client_async(request, stream)
+ .await
+ .context("websocket handshake")?;
+ Ok(Conn::new(stream))
+ } else {
+ Err(anyhow!("invalid server url: {}", *ZED_SERVER_URL))
+ }
+ })
+ }
+
+ pub fn authenticate_with_browser(
+ self: &Arc<Self>,
+ cx: &AsyncAppContext,
+ ) -> Task<Result<(u64, String)>> {
+ let platform = cx.platform();
+ let executor = cx.background();
executor.clone().spawn(async move {
if let Some((user_id, access_token)) = platform
.read_credentials(&ZED_SERVER_URL)
@@ -243,7 +397,7 @@ impl Client {
.flatten()
{
log::info!("already signed in. user_id: {}", user_id);
- return Ok((user_id, String::from_utf8(access_token).unwrap()));
+ return Ok((user_id.parse()?, String::from_utf8(access_token).unwrap()));
}
// Generate a pair of asymmetric encryption keys. The public key will be used by the
@@ -309,21 +463,23 @@ impl Client {
platform
.write_credentials(&ZED_SERVER_URL, &user_id, access_token.as_bytes())
.log_err();
- Ok((user_id.to_string(), access_token))
+ Ok((user_id.parse()?, access_token))
})
}
- pub async fn disconnect(&self) -> Result<()> {
+ pub async fn disconnect(self: &Arc<Self>, cx: &AsyncAppContext) -> Result<()> {
let conn_id = self.connection_id()?;
self.peer.disconnect(conn_id).await;
+ self.set_status(Status::Disconnected, cx);
Ok(())
}
fn connection_id(&self) -> Result<ConnectionId> {
- self.state
- .read()
- .connection_id
- .ok_or_else(|| anyhow!("not connected"))
+ if let Status::Connected { connection_id, .. } = *self.status().borrow() {
+ Ok(connection_id)
+ } else {
+ Err(anyhow!("not connected"))
+ }
}
pub async fn send<T: EnvelopedMessage>(&self, message: T) -> Result<()> {
@@ -343,35 +499,6 @@ impl Client {
}
}
-pub trait MessageHandler<'a, M: proto::EnvelopedMessage>: Clone {
- type Output: 'a + Future<Output = anyhow::Result<()>>;
-
- fn handle(
- &self,
- message: TypedEnvelope<M>,
- rpc: &'a Client,
- cx: &'a mut gpui::AsyncAppContext,
- ) -> Self::Output;
-}
-
-impl<'a, M, F, Fut> MessageHandler<'a, M> for F
-where
- M: proto::EnvelopedMessage,
- F: Clone + Fn(TypedEnvelope<M>, &'a Client, &'a mut gpui::AsyncAppContext) -> Fut,
- Fut: 'a + Future<Output = anyhow::Result<()>>,
-{
- type Output = Fut;
-
- fn handle(
- &self,
- message: TypedEnvelope<M>,
- rpc: &'a Client,
- cx: &'a mut gpui::AsyncAppContext,
- ) -> Self::Output {
- (self)(message, rpc, cx)
- }
-}
-
const WORKTREE_URL_PREFIX: &'static str = "zed://worktrees/";
pub fn encode_worktree_url(id: u64, access_token: &str) -> String {
@@ -396,13 +523,62 @@ const LOGIN_RESPONSE: &'static str = "
</html>
";
-#[test]
-fn test_encode_and_decode_worktree_url() {
- let url = encode_worktree_url(5, "deadbeef");
- assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
- assert_eq!(
- decode_worktree_url(&format!("\n {}\t", url)),
- Some((5, "deadbeef".to_string()))
- );
- assert_eq!(decode_worktree_url("not://the-right-format"), None);
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::test::FakeServer;
+ use gpui::TestAppContext;
+
+ #[gpui::test(iterations = 10)]
+ async fn test_heartbeat(cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let mut client = Client::new();
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+ cx.foreground().advance_clock(Duration::from_secs(10));
+ let ping = server.receive::<proto::Ping>().await.unwrap();
+ server.respond(ping.receipt(), proto::Ack {}).await;
+
+ cx.foreground().advance_clock(Duration::from_secs(10));
+ let ping = server.receive::<proto::Ping>().await.unwrap();
+ server.respond(ping.receipt(), proto::Ack {}).await;
+
+ client.disconnect(&cx.to_async()).await.unwrap();
+ assert!(server.receive::<proto::Ping>().await.is_err());
+ }
+
+ #[gpui::test(iterations = 10)]
+ async fn test_reconnection(cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
+ let user_id = 5;
+ let mut client = Client::new();
+ let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+ let mut status = client.status();
+ assert!(matches!(
+ status.recv().await,
+ Some(Status::Connected { .. })
+ ));
+
+ server.forbid_connections();
+ server.disconnect().await;
+ while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {}
+
+ server.allow_connections();
+ cx.foreground().advance_clock(Duration::from_secs(10));
+ while !matches!(status.recv().await, Some(Status::Connected { .. })) {}
+ }
+
+ #[test]
+ fn test_encode_and_decode_worktree_url() {
+ let url = encode_worktree_url(5, "deadbeef");
+ assert_eq!(decode_worktree_url(&url), Some((5, "deadbeef".to_string())));
+ assert_eq!(
+ decode_worktree_url(&format!("\n {}\t", url)),
+ Some((5, "deadbeef".to_string()))
+ );
+ assert_eq!(decode_worktree_url("not://the-right-format"), None);
+ }
}
@@ -3,24 +3,27 @@ use crate::{
channel::ChannelList,
fs::RealFs,
language::LanguageRegistry,
- rpc,
+ rpc::{self, Client},
settings::{self, ThemeRegistry},
time::ReplicaId,
user::UserStore,
AppState,
};
-use gpui::{Entity, ModelHandle, MutableAppContext};
+use anyhow::{anyhow, Result};
+use gpui::{AsyncAppContext, Entity, ModelHandle, MutableAppContext, TestAppContext};
use parking_lot::Mutex;
+use postage::{mpsc, prelude::Stream as _};
use smol::channel;
use std::{
marker::PhantomData,
path::{Path, PathBuf},
- sync::Arc,
+ sync::{
+ atomic::{AtomicBool, Ordering::SeqCst},
+ Arc,
+ },
};
use tempdir::TempDir;
-
-#[cfg(feature = "test-support")]
-pub use zrpc::test::Channel;
+use zrpc::{proto, Conn, ConnectionId, Peer, Receipt, TypedEnvelope};
#[cfg(test)]
#[ctor::ctor]
@@ -195,3 +198,117 @@ impl<T: Entity> Observer<T> {
(observer, notify_rx)
}
}
+
+pub struct FakeServer {
+ peer: Arc<Peer>,
+ incoming: Mutex<Option<mpsc::Receiver<Box<dyn proto::AnyTypedEnvelope>>>>,
+ connection_id: Mutex<Option<ConnectionId>>,
+ forbid_connections: AtomicBool,
+}
+
+impl FakeServer {
+ pub async fn for_client(
+ client_user_id: u64,
+ client: &mut Arc<Client>,
+ cx: &TestAppContext,
+ ) -> Arc<Self> {
+ let result = Arc::new(Self {
+ peer: Peer::new(),
+ incoming: Default::default(),
+ connection_id: Default::default(),
+ forbid_connections: Default::default(),
+ });
+
+ Arc::get_mut(client)
+ .unwrap()
+ .set_login_and_connect_callbacks(
+ move |cx| {
+ cx.spawn(|_| async move {
+ let access_token = "the-token".to_string();
+ Ok((client_user_id, access_token))
+ })
+ },
+ {
+ let server = result.clone();
+ move |user_id, access_token, cx| {
+ assert_eq!(user_id, client_user_id);
+ assert_eq!(access_token, "the-token");
+ cx.spawn({
+ let server = server.clone();
+ move |cx| async move { server.connect(&cx).await }
+ })
+ }
+ },
+ );
+
+ client
+ .authenticate_and_connect(&cx.to_async())
+ .await
+ .unwrap();
+ result
+ }
+
+ pub async fn disconnect(&self) {
+ self.peer.disconnect(self.connection_id()).await;
+ self.connection_id.lock().take();
+ self.incoming.lock().take();
+ }
+
+ async fn connect(&self, cx: &AsyncAppContext) -> Result<Conn> {
+ if self.forbid_connections.load(SeqCst) {
+ Err(anyhow!("server is forbidding connections"))
+ } else {
+ let (client_conn, server_conn, _) = Conn::in_memory();
+ let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
+ cx.background().spawn(io).detach();
+ *self.incoming.lock() = Some(incoming);
+ *self.connection_id.lock() = Some(connection_id);
+ Ok(client_conn)
+ }
+ }
+
+ pub fn forbid_connections(&self) {
+ self.forbid_connections.store(true, SeqCst);
+ }
+
+ pub fn allow_connections(&self) {
+ self.forbid_connections.store(false, SeqCst);
+ }
+
+ pub async fn send<T: proto::EnvelopedMessage>(&self, message: T) {
+ self.peer.send(self.connection_id(), message).await.unwrap();
+ }
+
+ pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
+ let message = self
+ .incoming
+ .lock()
+ .as_mut()
+ .expect("not connected")
+ .recv()
+ .await
+ .ok_or_else(|| anyhow!("other half hung up"))?;
+ let type_name = message.payload_type_name();
+ Ok(*message
+ .into_any()
+ .downcast::<TypedEnvelope<M>>()
+ .unwrap_or_else(|_| {
+ panic!(
+ "fake server received unexpected message type: {:?}",
+ type_name
+ );
+ }))
+ }
+
+ pub async fn respond<T: proto::RequestMessage>(
+ &self,
+ receipt: Receipt<T>,
+ response: T::Response,
+ ) {
+ self.peer.respond(receipt, response).await.unwrap()
+ }
+
+ fn connection_id(&self) -> ConnectionId {
+ self.connection_id.lock().expect("not connected")
+ }
+}
@@ -234,6 +234,7 @@ impl Worktree {
.into_iter()
.map(|p| (PeerId(p.peer_id), p.replica_id as ReplicaId))
.collect(),
+ queued_operations: Default::default(),
languages,
_subscriptions,
})
@@ -656,6 +657,7 @@ pub struct LocalWorktree {
shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
peers: HashMap<PeerId, ReplicaId>,
languages: Arc<LanguageRegistry>,
+ queued_operations: Vec<(u64, Operation)>,
fs: Arc<dyn Fs>,
}
@@ -711,6 +713,7 @@ impl LocalWorktree {
poll_task: None,
open_buffers: Default::default(),
shared_buffers: Default::default(),
+ queued_operations: Default::default(),
peers: Default::default(),
languages,
fs,
@@ -1091,6 +1094,7 @@ pub struct RemoteWorktree {
open_buffers: HashMap<usize, RemoteBuffer>,
peers: HashMap<PeerId, ReplicaId>,
languages: Arc<LanguageRegistry>,
+ queued_operations: Vec<(u64, Operation)>,
_subscriptions: Vec<rpc::Subscription>,
}
@@ -1550,16 +1554,23 @@ impl File {
.map(|share| (share.rpc.clone(), share.remote_id)),
Worktree::Remote(worktree) => Some((worktree.rpc.clone(), worktree.remote_id)),
} {
- cx.spawn(|_, _| async move {
+ cx.spawn(|worktree, mut cx| async move {
if let Err(error) = rpc
- .send(proto::UpdateBuffer {
+ .request(proto::UpdateBuffer {
worktree_id: remote_id,
buffer_id,
- operations: Some(operation).iter().map(Into::into).collect(),
+ operations: vec![(&operation).into()],
})
.await
{
- log::error!("error sending buffer operation: {}", error);
+ worktree.update(&mut cx, |worktree, _| {
+ log::error!("error sending buffer operation: {}", error);
+ match worktree {
+ Worktree::Local(t) => &mut t.queued_operations,
+ Worktree::Remote(t) => &mut t.queued_operations,
+ }
+ .push((buffer_id, operation));
+ });
}
})
.detach();
@@ -1582,7 +1593,7 @@ impl File {
.await
{
log::error!("error closing remote buffer: {}", error);
- };
+ }
})
.detach();
}
@@ -6,9 +6,9 @@ message Envelope {
optional uint32 responding_to = 2;
optional uint32 original_sender_id = 3;
oneof payload {
- Error error = 4;
- Ping ping = 5;
- Pong pong = 6;
+ Ack ack = 4;
+ Error error = 5;
+ Ping ping = 6;
ShareWorktree share_worktree = 7;
ShareWorktreeResponse share_worktree_response = 8;
OpenWorktree open_worktree = 9;
@@ -40,13 +40,9 @@ message Envelope {
// Messages
-message Ping {
- int32 id = 1;
-}
+message Ping {}
-message Pong {
- int32 id = 2;
-}
+message Ack {}
message Error {
string message = 1;
@@ -0,0 +1,101 @@
+use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
+use futures::{channel::mpsc, SinkExt as _, Stream, StreamExt as _};
+use std::{io, task::Poll};
+
+pub struct Conn {
+ pub(crate) tx:
+ Box<dyn 'static + Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+ pub(crate) rx: Box<
+ dyn 'static
+ + Send
+ + Unpin
+ + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+ >,
+}
+
+impl Conn {
+ pub fn new<S>(stream: S) -> Self
+ where
+ S: 'static
+ + Send
+ + Unpin
+ + futures::Sink<WebSocketMessage, Error = WebSocketError>
+ + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>,
+ {
+ let (tx, rx) = stream.split();
+ Self {
+ tx: Box::new(tx),
+ rx: Box::new(rx),
+ }
+ }
+
+ pub async fn send(&mut self, message: WebSocketMessage) -> Result<(), WebSocketError> {
+ self.tx.send(message).await
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn in_memory() -> (Self, Self, postage::watch::Sender<Option<()>>) {
+ let (kill_tx, mut kill_rx) = postage::watch::channel_with(None);
+ postage::stream::Stream::try_recv(&mut kill_rx).unwrap();
+
+ let (a_tx, a_rx) = Self::channel(kill_rx.clone());
+ let (b_tx, b_rx) = Self::channel(kill_rx);
+ (
+ Self { tx: a_tx, rx: b_rx },
+ Self { tx: b_tx, rx: a_rx },
+ kill_tx,
+ )
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ fn channel(
+ kill_rx: postage::watch::Receiver<Option<()>>,
+ ) -> (
+ Box<dyn Send + Unpin + futures::Sink<WebSocketMessage, Error = WebSocketError>>,
+ Box<dyn Send + Unpin + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>>,
+ ) {
+ use futures::{future, SinkExt as _};
+ use io::{Error, ErrorKind};
+
+ let (tx, rx) = mpsc::unbounded::<WebSocketMessage>();
+ let tx = tx
+ .sink_map_err(|e| WebSocketError::from(Error::new(ErrorKind::Other, e)))
+ .with({
+ let kill_rx = kill_rx.clone();
+ move |msg| {
+ if kill_rx.borrow().is_none() {
+ future::ready(Ok(msg))
+ } else {
+ future::ready(Err(Error::new(ErrorKind::Other, "connection killed").into()))
+ }
+ }
+ });
+ let rx = KillableReceiver { kill_rx, rx };
+
+ (Box::new(tx), Box::new(rx))
+ }
+}
+
+struct KillableReceiver {
+ rx: mpsc::UnboundedReceiver<WebSocketMessage>,
+ kill_rx: postage::watch::Receiver<Option<()>>,
+}
+
+impl Stream for KillableReceiver {
+ type Item = Result<WebSocketMessage, WebSocketError>;
+
+ fn poll_next(
+ mut self: std::pin::Pin<&mut Self>,
+ cx: &mut std::task::Context<'_>,
+ ) -> Poll<Option<Self::Item>> {
+ if let Poll::Ready(Some(Some(()))) = self.kill_rx.poll_next_unpin(cx) {
+ Poll::Ready(Some(Err(io::Error::new(
+ io::ErrorKind::Other,
+ "connection killed",
+ )
+ .into())))
+ } else {
+ self.rx.poll_next_unpin(cx).map(|value| value.map(Ok))
+ }
+ }
+}
@@ -1,7 +1,6 @@
pub mod auth;
+mod conn;
mod peer;
pub mod proto;
-#[cfg(any(test, feature = "test-support"))]
-pub mod test;
-
+pub use conn::Conn;
pub use peer::*;
@@ -1,8 +1,8 @@
-use crate::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
+use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage};
+use super::Conn;
use anyhow::{anyhow, Context, Result};
use async_lock::{Mutex, RwLock};
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use futures::{FutureExt, StreamExt};
+use futures::FutureExt as _;
use postage::{
mpsc,
prelude::{Sink as _, Stream as _},
@@ -98,21 +98,14 @@ impl Peer {
})
}
- pub async fn add_connection<Conn>(
+ pub async fn add_connection(
self: &Arc<Self>,
conn: Conn,
) -> (
ConnectionId,
impl Future<Output = anyhow::Result<()>> + Send,
mpsc::Receiver<Box<dyn AnyTypedEnvelope>>,
- )
- where
- Conn: futures::Sink<WebSocketMessage, Error = WebSocketError>
- + futures::Stream<Item = Result<WebSocketMessage, WebSocketError>>
- + Send
- + Unpin,
- {
- let (tx, rx) = conn.split();
+ ) {
let connection_id = ConnectionId(
self.next_connection_id
.fetch_add(1, atomic::Ordering::SeqCst),
@@ -124,9 +117,10 @@ impl Peer {
next_message_id: Default::default(),
response_channels: Default::default(),
};
- let mut writer = MessageStream::new(tx);
- let mut reader = MessageStream::new(rx);
+ let mut writer = MessageStream::new(conn.tx);
+ let mut reader = MessageStream::new(conn.rx);
+ let this = self.clone();
let response_channels = connection.response_channels.clone();
let handle_io = async move {
loop {
@@ -147,6 +141,7 @@ impl Peer {
if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) {
if incoming_tx.send(envelope).await.is_err() {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
return Ok(())
}
} else {
@@ -158,6 +153,7 @@ impl Peer {
}
Err(error) => {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
Err(error).context("received invalid RPC message")?;
}
},
@@ -165,11 +161,13 @@ impl Peer {
Some(outgoing) => {
if let Err(result) = writer.write_message(&outgoing).await {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
Err(result).context("failed to write RPC message")?;
}
}
None => {
response_channels.lock().await.clear();
+ this.connections.write().await.remove(&connection_id);
return Ok(())
}
}
@@ -342,7 +340,9 @@ impl Peer {
#[cfg(test)]
mod tests {
use super::*;
- use crate::{test, TypedEnvelope};
+ use crate::TypedEnvelope;
+ use async_tungstenite::tungstenite::Message as WebSocketMessage;
+ use futures::StreamExt as _;
#[test]
fn test_request_response() {
@@ -352,12 +352,12 @@ mod tests {
let client1 = Peer::new();
let client2 = Peer::new();
- let (client1_to_server_conn, server_to_client_1_conn) = test::Channel::bidirectional();
+ let (client1_to_server_conn, server_to_client_1_conn, _) = Conn::in_memory();
let (client1_conn_id, io_task1, _) =
client1.add_connection(client1_to_server_conn).await;
let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await;
- let (client2_to_server_conn, server_to_client_2_conn) = test::Channel::bidirectional();
+ let (client2_to_server_conn, server_to_client_2_conn, _) = Conn::in_memory();
let (client2_conn_id, io_task3, _) =
client2.add_connection(client2_to_server_conn).await;
let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await;
@@ -371,18 +371,18 @@ mod tests {
assert_eq!(
client1
- .request(client1_conn_id, proto::Ping { id: 1 },)
+ .request(client1_conn_id, proto::Ping {},)
.await
.unwrap(),
- proto::Pong { id: 1 }
+ proto::Ack {}
);
assert_eq!(
client2
- .request(client2_conn_id, proto::Ping { id: 2 },)
+ .request(client2_conn_id, proto::Ping {},)
.await
.unwrap(),
- proto::Pong { id: 2 }
+ proto::Ack {}
);
assert_eq!(
@@ -438,13 +438,7 @@ mod tests {
let envelope = envelope.into_any();
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt();
- peer.respond(
- receipt,
- proto::Pong {
- id: envelope.payload.id,
- },
- )
- .await?
+ peer.respond(receipt, proto::Ack {}).await?
} else if let Some(envelope) =
envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
{
@@ -492,7 +486,7 @@ mod tests {
#[test]
fn test_disconnect() {
smol::block_on(async move {
- let (client_conn, mut server_conn) = test::Channel::bidirectional();
+ let (client_conn, mut server_conn, _) = Conn::in_memory();
let client = Peer::new();
let (connection_id, io_handler, mut incoming) =
@@ -516,18 +510,17 @@ mod tests {
io_ended_rx.recv().await;
messages_ended_rx.recv().await;
- assert!(
- futures::SinkExt::send(&mut server_conn, WebSocketMessage::Binary(vec![]))
- .await
- .is_err()
- );
+ assert!(server_conn
+ .send(WebSocketMessage::Binary(vec![]))
+ .await
+ .is_err());
});
}
#[test]
fn test_io_error() {
smol::block_on(async move {
- let (client_conn, server_conn) = test::Channel::bidirectional();
+ let (client_conn, server_conn, _) = Conn::in_memory();
drop(server_conn);
let client = Peer::new();
@@ -537,7 +530,7 @@ mod tests {
smol::spawn(async move { incoming.next().await }).detach();
let err = client
- .request(connection_id, proto::Ping { id: 42 })
+ .request(connection_id, proto::Ping {})
.await
.unwrap_err();
assert_eq!(err.to_string(), "connection was closed");
@@ -120,6 +120,7 @@ macro_rules! entity_messages {
}
messages!(
+ Ack,
AddPeer,
BufferSaved,
ChannelMessageSent,
@@ -140,7 +141,6 @@ messages!(
OpenWorktree,
OpenWorktreeResponse,
Ping,
- Pong,
RemovePeer,
SaveBuffer,
SendChannelMessage,
@@ -157,8 +157,9 @@ request_messages!(
(JoinChannel, JoinChannelResponse),
(OpenBuffer, OpenBufferResponse),
(OpenWorktree, OpenWorktreeResponse),
- (Ping, Pong),
+ (Ping, Ack),
(SaveBuffer, BufferSaved),
+ (UpdateBuffer, Ack),
(ShareWorktree, ShareWorktreeResponse),
(SendChannelMessage, SendChannelMessageResponse),
(GetChannelMessages, GetChannelMessagesResponse),
@@ -247,30 +248,3 @@ impl From<SystemTime> for Timestamp {
}
}
}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use crate::test;
-
- #[test]
- fn test_round_trip_message() {
- smol::block_on(async {
- let stream = test::Channel::new();
- let message1 = Ping { id: 5 }.into_envelope(3, None, None);
- let message2 = OpenBuffer {
- worktree_id: 0,
- path: "some/path".to_string(),
- }
- .into_envelope(5, None, None);
-
- let mut message_stream = MessageStream::new(stream);
- message_stream.write_message(&message1).await.unwrap();
- message_stream.write_message(&message2).await.unwrap();
- let decoded_message1 = message_stream.read_message().await.unwrap();
- let decoded_message2 = message_stream.read_message().await.unwrap();
- assert_eq!(decoded_message1, message1);
- assert_eq!(decoded_message2, message2);
- });
- }
-}
@@ -1,64 +0,0 @@
-use async_tungstenite::tungstenite::{Error as WebSocketError, Message as WebSocketMessage};
-use std::{
- io,
- pin::Pin,
- task::{Context, Poll},
-};
-
-pub struct Channel {
- tx: futures::channel::mpsc::UnboundedSender<WebSocketMessage>,
- rx: futures::channel::mpsc::UnboundedReceiver<WebSocketMessage>,
-}
-
-impl Channel {
- pub fn new() -> Self {
- let (tx, rx) = futures::channel::mpsc::unbounded();
- Self { tx, rx }
- }
-
- pub fn bidirectional() -> (Self, Self) {
- let (a_tx, a_rx) = futures::channel::mpsc::unbounded();
- let (b_tx, b_rx) = futures::channel::mpsc::unbounded();
- let a = Self { tx: a_tx, rx: b_rx };
- let b = Self { tx: b_tx, rx: a_rx };
- (a, b)
- }
-}
-
-impl futures::Sink<WebSocketMessage> for Channel {
- type Error = WebSocketError;
-
- fn poll_ready(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.tx)
- .poll_ready(cx)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-
- fn start_send(mut self: Pin<&mut Self>, item: WebSocketMessage) -> Result<(), Self::Error> {
- Pin::new(&mut self.tx)
- .start_send(item)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-
- fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.tx)
- .poll_flush(cx)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-
- fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
- Pin::new(&mut self.tx)
- .poll_close(cx)
- .map_err(|err| io::Error::new(io::ErrorKind::Other, err).into())
- }
-}
-
-impl futures::Stream for Channel {
- type Item = Result<WebSocketMessage, WebSocketError>;
-
- fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
- Pin::new(&mut self.rx)
- .poll_next(cx)
- .map(|i| i.map(|i| Ok(i)))
- }
-}