@@ -7,12 +7,14 @@ use super::{
};
use anyhow::anyhow;
use async_io::Timer;
-use async_std::task;
+use async_std::{
+ sync::{RwLock, RwLockReadGuard, RwLockWriteGuard},
+ task,
+};
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use collections::{HashMap, HashSet};
use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt};
use log::{as_debug, as_display};
-use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
use rpc::{
proto::{self, AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage},
Connection, ConnectionId, Peer, TypedEnvelope,
@@ -21,6 +23,9 @@ use sha1::{Digest as _, Sha1};
use std::{
any::TypeId,
future::Future,
+ marker::PhantomData,
+ ops::{Deref, DerefMut},
+ rc::Rc,
sync::Arc,
time::{Duration, Instant},
};
@@ -31,6 +36,7 @@ use tide::{
Request, Response,
};
use time::OffsetDateTime;
+use util::ResultExt;
type MessageHandler = Box<
dyn Send
@@ -58,6 +64,16 @@ pub struct RealExecutor;
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
+struct StoreReadGuard<'a> {
+ guard: RwLockReadGuard<'a, Store>,
+ _not_send: PhantomData<Rc<()>>,
+}
+
+struct StoreWriteGuard<'a> {
+ guard: RwLockWriteGuard<'a, Store>,
+ _not_send: PhantomData<Rc<()>>,
+}
+
impl Server {
pub fn new(
app_state: Arc<AppState>,
@@ -78,7 +94,7 @@ impl Server {
.add_message_handler(Server::unregister_project)
.add_request_handler(Server::share_project)
.add_message_handler(Server::unshare_project)
- .add_request_handler(Server::join_project)
+ .add_sync_request_handler(Server::join_project)
.add_message_handler(Server::leave_project)
.add_request_handler(Server::register_worktree)
.add_message_handler(Server::unregister_worktree)
@@ -170,6 +186,42 @@ impl Server {
})
}
+ /// Handle a request while holding a lock to the store. This is useful when we're registering
+ /// a connection but we want to respond on the connection before anybody else can send on it.
+ fn add_sync_request_handler<F, M>(&mut self, handler: F) -> &mut Self
+ where
+ F: 'static
+ + Send
+ + Sync
+ + Fn(Arc<Self>, &mut Store, TypedEnvelope<M>) -> tide::Result<M::Response>,
+ M: RequestMessage,
+ {
+ let handler = Arc::new(handler);
+ self.add_message_handler(move |server, envelope| {
+ let receipt = envelope.receipt();
+ let handler = handler.clone();
+ async move {
+ let mut store = server.store.write().await;
+ let response = (handler)(server.clone(), &mut *store, envelope);
+ match response {
+ Ok(response) => {
+ server.peer.respond(receipt, response)?;
+ Ok(())
+ }
+ Err(error) => {
+ server.peer.respond_with_error(
+ receipt,
+ proto::Error {
+ message: error.to_string(),
+ },
+ )?;
+ Err(error)
+ }
+ }
+ }
+ })
+ }
+
pub fn handle_connection<E: Executor>(
self: &Arc<Self>,
connection: Connection,
@@ -197,9 +249,10 @@ impl Server {
let _ = send_connection_id.send(connection_id).await;
}
- this.state_mut().add_connection(connection_id, user_id);
- if let Err(err) = this.update_contacts_for_users(&[user_id]) {
- log::error!("error updating contacts for {:?}: {}", user_id, err);
+ {
+ let mut state = this.state_mut().await;
+ state.add_connection(connection_id, user_id);
+ this.update_contacts_for_users(&*state, &[user_id]);
}
let handle_io = handle_io.fuse();
@@ -257,7 +310,8 @@ impl Server {
async fn sign_out(self: &mut Arc<Self>, connection_id: ConnectionId) -> tide::Result<()> {
self.peer.disconnect(connection_id);
- let removed_connection = self.state_mut().remove_connection(connection_id)?;
+ let mut state = self.state_mut().await;
+ let removed_connection = state.remove_connection(connection_id)?;
for (project_id, project) in removed_connection.hosted_projects {
if let Some(share) = project.share {
@@ -268,7 +322,7 @@ impl Server {
self.peer
.send(conn_id, proto::UnshareProject { project_id })
},
- )?;
+ );
}
}
@@ -281,10 +335,10 @@ impl Server {
peer_id: connection_id.0,
},
)
- })?;
+ });
}
- self.update_contacts_for_users(removed_connection.contact_ids.iter())?;
+ self.update_contacts_for_users(&*state, removed_connection.contact_ids.iter());
Ok(())
}
@@ -293,11 +347,11 @@ impl Server {
}
async fn register_project(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::RegisterProject>,
) -> tide::Result<proto::RegisterProjectResponse> {
let project_id = {
- let mut state = self.state_mut();
+ let mut state = self.state_mut().await;
let user_id = state.user_id_for_connection(request.sender_id)?;
state.register_project(request.sender_id, user_id)
};
@@ -305,51 +359,49 @@ impl Server {
}
async fn unregister_project(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::UnregisterProject>,
) -> tide::Result<()> {
- let project = self
- .state_mut()
- .unregister_project(request.payload.project_id, request.sender_id)?;
- self.update_contacts_for_users(project.authorized_user_ids().iter())?;
+ let mut state = self.state_mut().await;
+ let project = state.unregister_project(request.payload.project_id, request.sender_id)?;
+ self.update_contacts_for_users(&*state, &project.authorized_user_ids());
Ok(())
}
async fn share_project(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::ShareProject>,
) -> tide::Result<proto::Ack> {
self.state_mut()
+ .await
.share_project(request.payload.project_id, request.sender_id);
Ok(proto::Ack {})
}
async fn unshare_project(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::UnshareProject>,
) -> tide::Result<()> {
let project_id = request.payload.project_id;
- let project = self
- .state_mut()
- .unshare_project(project_id, request.sender_id)?;
-
+ let mut state = self.state_mut().await;
+ let project = state.unshare_project(project_id, request.sender_id)?;
broadcast(request.sender_id, project.connection_ids, |conn_id| {
self.peer
.send(conn_id, proto::UnshareProject { project_id })
- })?;
- self.update_contacts_for_users(&project.authorized_user_ids)?;
+ });
+ self.update_contacts_for_users(&mut *state, &project.authorized_user_ids);
Ok(())
}
- async fn join_project(
- mut self: Arc<Server>,
+ fn join_project(
+ self: Arc<Server>,
+ state: &mut Store,
request: TypedEnvelope<proto::JoinProject>,
) -> tide::Result<proto::JoinProjectResponse> {
let project_id = request.payload.project_id;
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
- let (response, connection_ids, contact_user_ids) = self
- .state_mut()
+ let user_id = state.user_id_for_connection(request.sender_id)?;
+ let (response, connection_ids, contact_user_ids) = state
.join_project(request.sender_id, user_id, project_id)
.and_then(|joined| {
let share = joined.project.share()?;
@@ -410,19 +462,19 @@ impl Server {
}),
},
)
- })?;
- self.update_contacts_for_users(&contact_user_ids)?;
+ });
+ self.update_contacts_for_users(state, &contact_user_ids);
Ok(response)
}
async fn leave_project(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::LeaveProject>,
) -> tide::Result<()> {
let sender_id = request.sender_id;
let project_id = request.payload.project_id;
- let worktree = self.state_mut().leave_project(sender_id, project_id)?;
-
+ let mut state = self.state_mut().await;
+ let worktree = state.leave_project(sender_id, project_id)?;
broadcast(sender_id, worktree.connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -431,60 +483,57 @@ impl Server {
peer_id: sender_id.0,
},
)
- })?;
- self.update_contacts_for_users(&worktree.authorized_user_ids)?;
-
+ });
+ self.update_contacts_for_users(&*state, &worktree.authorized_user_ids);
Ok(())
}
async fn register_worktree(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::RegisterWorktree>,
) -> tide::Result<proto::Ack> {
- let host_user_id = self.state().user_id_for_connection(request.sender_id)?;
-
let mut contact_user_ids = HashSet::default();
- contact_user_ids.insert(host_user_id);
for github_login in &request.payload.authorized_logins {
let contact_user_id = self.app_state.db.create_user(github_login, false).await?;
contact_user_ids.insert(contact_user_id);
}
+ let mut state = self.state_mut().await;
+ let host_user_id = state.user_id_for_connection(request.sender_id)?;
+ contact_user_ids.insert(host_user_id);
+
let contact_user_ids = contact_user_ids.into_iter().collect::<Vec<_>>();
- let guest_connection_ids;
- {
- let mut state = self.state_mut();
- guest_connection_ids = state
- .read_project(request.payload.project_id, request.sender_id)?
- .guest_connection_ids();
- state.register_worktree(
- request.payload.project_id,
- request.payload.worktree_id,
- request.sender_id,
- Worktree {
- authorized_user_ids: contact_user_ids.clone(),
- root_name: request.payload.root_name.clone(),
- visible: request.payload.visible,
- },
- )?;
- }
+ let guest_connection_ids = state
+ .read_project(request.payload.project_id, request.sender_id)?
+ .guest_connection_ids();
+ state.register_worktree(
+ request.payload.project_id,
+ request.payload.worktree_id,
+ request.sender_id,
+ Worktree {
+ authorized_user_ids: contact_user_ids.clone(),
+ root_name: request.payload.root_name.clone(),
+ visible: request.payload.visible,
+ },
+ )?;
+
broadcast(request.sender_id, guest_connection_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
- self.update_contacts_for_users(&contact_user_ids)?;
+ });
+ self.update_contacts_for_users(&*state, &contact_user_ids);
Ok(proto::Ack {})
}
async fn unregister_worktree(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::UnregisterWorktree>,
) -> tide::Result<()> {
let project_id = request.payload.project_id;
let worktree_id = request.payload.worktree_id;
+ let mut state = self.state_mut().await;
let (worktree, guest_connection_ids) =
- self.state_mut()
- .unregister_worktree(project_id, worktree_id, request.sender_id)?;
+ state.unregister_worktree(project_id, worktree_id, request.sender_id)?;
broadcast(request.sender_id, guest_connection_ids, |conn_id| {
self.peer.send(
conn_id,
@@ -493,16 +542,16 @@ impl Server {
worktree_id,
},
)
- })?;
- self.update_contacts_for_users(&worktree.authorized_user_ids)?;
+ });
+ self.update_contacts_for_users(&*state, &worktree.authorized_user_ids);
Ok(())
}
async fn update_worktree(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
) -> tide::Result<proto::Ack> {
- let connection_ids = self.state_mut().update_worktree(
+ let connection_ids = self.state_mut().await.update_worktree(
request.sender_id,
request.payload.project_id,
request.payload.worktree_id,
@@ -513,13 +562,13 @@ impl Server {
broadcast(request.sender_id, connection_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(proto::Ack {})
}
async fn update_diagnostic_summary(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::UpdateDiagnosticSummary>,
) -> tide::Result<()> {
let summary = request
@@ -527,7 +576,7 @@ impl Server {
.summary
.clone()
.ok_or_else(|| anyhow!("invalid summary"))?;
- let receiver_ids = self.state_mut().update_diagnostic_summary(
+ let receiver_ids = self.state_mut().await.update_diagnostic_summary(
request.payload.project_id,
request.payload.worktree_id,
request.sender_id,
@@ -537,15 +586,15 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
async fn start_language_server(
- mut self: Arc<Server>,
+ self: Arc<Server>,
request: TypedEnvelope<proto::StartLanguageServer>,
) -> tide::Result<()> {
- let receiver_ids = self.state_mut().start_language_server(
+ let receiver_ids = self.state_mut().await.start_language_server(
request.payload.project_id,
request.sender_id,
request
@@ -557,7 +606,7 @@ impl Server {
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -567,11 +616,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -584,6 +634,7 @@ impl Server {
{
let host_connection_id = self
.state()
+ .await
.read_project(request.payload.remote_entity_id(), request.sender_id)?
.host_connection_id;
Ok(self
@@ -596,24 +647,25 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::SaveBuffer>,
) -> tide::Result<proto::BufferSaved> {
- let host;
- let mut guests;
- {
- let state = self.state();
- let project = state.read_project(request.payload.project_id, request.sender_id)?;
- host = project.host_connection_id;
- guests = project.guest_connection_ids()
- }
-
+ let host = self
+ .state()
+ .await
+ .read_project(request.payload.project_id, request.sender_id)?
+ .host_connection_id;
let response = self
.peer
.forward_request(request.sender_id, host, request.payload.clone())
.await?;
+ let mut guests = self
+ .state()
+ .await
+ .read_project(request.payload.project_id, request.sender_id)?
+ .connection_ids();
guests.retain(|guest_connection_id| *guest_connection_id != request.sender_id);
broadcast(host, guests, |conn_id| {
self.peer.forward_send(host, conn_id, response.clone())
- })?;
+ });
Ok(response)
}
@@ -624,11 +676,12 @@ impl Server {
) -> tide::Result<proto::Ack> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(proto::Ack {})
}
@@ -638,11 +691,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -652,11 +706,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -666,11 +721,12 @@ impl Server {
) -> tide::Result<()> {
let receiver_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
broadcast(request.sender_id, receiver_ids, |connection_id| {
self.peer
.forward_send(request.sender_id, connection_id, request.payload.clone())
- })?;
+ });
Ok(())
}
@@ -682,6 +738,7 @@ impl Server {
let follower_id = request.sender_id;
if !self
.state()
+ .await
.project_connection_ids(request.payload.project_id, follower_id)?
.contains(&leader_id)
{
@@ -704,6 +761,7 @@ impl Server {
let leader_id = ConnectionId(request.payload.leader_id);
if !self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?
.contains(&leader_id)
{
@@ -720,6 +778,7 @@ impl Server {
) -> tide::Result<()> {
let connection_ids = self
.state()
+ .await
.project_connection_ids(request.payload.project_id, request.sender_id)?;
let leader_id = request
.payload
@@ -744,7 +803,10 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::GetChannels>,
) -> tide::Result<proto::GetChannelsResponse> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channels = self.app_state.db.get_accessible_channels(user_id).await?;
Ok(proto::GetChannelsResponse {
channels: channels
@@ -783,32 +845,33 @@ impl Server {
}
fn update_contacts_for_users<'a>(
- self: &Arc<Server>,
+ self: &Arc<Self>,
+ state: &Store,
user_ids: impl IntoIterator<Item = &'a UserId>,
- ) -> anyhow::Result<()> {
- let mut result = Ok(());
- let state = self.state();
+ ) {
for user_id in user_ids {
let contacts = state.contacts_for_user(*user_id);
for connection_id in state.connection_ids_for_user(*user_id) {
- if let Err(error) = self.peer.send(
- connection_id,
- proto::UpdateContacts {
- contacts: contacts.clone(),
- },
- ) {
- result = Err(error);
- }
+ self.peer
+ .send(
+ connection_id,
+ proto::UpdateContacts {
+ contacts: contacts.clone(),
+ },
+ )
+ .log_err();
}
}
- result
}
async fn join_channel(
- mut self: Arc<Self>,
+ self: Arc<Self>,
request: TypedEnvelope<proto::JoinChannel>,
) -> tide::Result<proto::JoinChannelResponse> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
@@ -819,7 +882,9 @@ impl Server {
Err(anyhow!("access denied"))?;
}
- self.state_mut().join_channel(request.sender_id, channel_id);
+ self.state_mut()
+ .await
+ .join_channel(request.sender_id, channel_id);
let messages = self
.app_state
.db
@@ -841,10 +906,13 @@ impl Server {
}
async fn leave_channel(
- mut self: Arc<Self>,
+ self: Arc<Self>,
request: TypedEnvelope<proto::LeaveChannel>,
) -> tide::Result<()> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
@@ -856,6 +924,7 @@ impl Server {
}
self.state_mut()
+ .await
.leave_channel(request.sender_id, channel_id);
Ok(())
@@ -869,7 +938,7 @@ impl Server {
let user_id;
let connection_ids;
{
- let state = self.state();
+ let state = self.state().await;
user_id = state.user_id_for_connection(request.sender_id)?;
connection_ids = state.channel_connection_ids(channel_id)?;
}
@@ -910,7 +979,7 @@ impl Server {
message: Some(message.clone()),
},
)
- })?;
+ });
Ok(proto::SendChannelMessageResponse {
message: Some(message),
})
@@ -920,7 +989,10 @@ impl Server {
self: Arc<Self>,
request: TypedEnvelope<proto::GetChannelMessages>,
) -> tide::Result<proto::GetChannelMessagesResponse> {
- let user_id = self.state().user_id_for_connection(request.sender_id)?;
+ let user_id = self
+ .state()
+ .await
+ .user_id_for_connection(request.sender_id)?;
let channel_id = ChannelId::from_proto(request.payload.channel_id);
if !self
.app_state
@@ -956,12 +1028,57 @@ impl Server {
})
}
- fn state<'a>(self: &'a Arc<Self>) -> RwLockReadGuard<'a, Store> {
- self.store.read()
+ async fn state<'a>(self: &'a Arc<Self>) -> StoreReadGuard<'a> {
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ let guard = self.store.read().await;
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ StoreReadGuard {
+ guard,
+ _not_send: PhantomData,
+ }
+ }
+
+ async fn state_mut<'a>(self: &'a Arc<Self>) -> StoreWriteGuard<'a> {
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ let guard = self.store.write().await;
+ #[cfg(test)]
+ async_std::task::yield_now().await;
+ StoreWriteGuard {
+ guard,
+ _not_send: PhantomData,
+ }
+ }
+}
+
+impl<'a> Deref for StoreReadGuard<'a> {
+ type Target = Store;
+
+ fn deref(&self) -> &Self::Target {
+ &*self.guard
+ }
+}
+
+impl<'a> Deref for StoreWriteGuard<'a> {
+ type Target = Store;
+
+ fn deref(&self) -> &Self::Target {
+ &*self.guard
}
+}
- fn state_mut<'a>(self: &'a mut Arc<Self>) -> RwLockWriteGuard<'a, Store> {
- self.store.write()
+impl<'a> DerefMut for StoreWriteGuard<'a> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ &mut *self.guard
+ }
+}
+
+impl<'a> Drop for StoreWriteGuard<'a> {
+ fn drop(&mut self) {
+ #[cfg(test)]
+ self.check_invariants();
}
}
@@ -977,25 +1094,15 @@ impl Executor for RealExecutor {
}
}
-fn broadcast<F>(
- sender_id: ConnectionId,
- receiver_ids: Vec<ConnectionId>,
- mut f: F,
-) -> anyhow::Result<()>
+fn broadcast<F>(sender_id: ConnectionId, receiver_ids: Vec<ConnectionId>, mut f: F)
where
F: FnMut(ConnectionId) -> anyhow::Result<()>,
{
- let mut result = Ok(());
for receiver_id in receiver_ids {
if receiver_id != sender_id {
- if let Err(error) = f(receiver_id) {
- if result.is_ok() {
- result = Err(error);
- }
- }
+ f(receiver_id).log_err();
}
}
- result
}
pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
@@ -1087,7 +1194,11 @@ mod tests {
self, ConfirmCodeAction, ConfirmCompletion, ConfirmRename, Editor, Input, Redo, Rename,
ToOffset, ToggleCodeActions, Undo,
};
- use gpui::{executor, geometry::vector::vec2f, ModelHandle, TestAppContext, ViewHandle};
+ use gpui::{
+ executor::{self, Deterministic},
+ geometry::vector::vec2f,
+ ModelHandle, TestAppContext, ViewHandle,
+ };
use language::{
range_to_lsp, tree_sitter_rust, Diagnostic, DiagnosticEntry, FakeLspAdapter, Language,
LanguageConfig, LanguageRegistry, OffsetRangeExt, Point, Rope,
@@ -1106,7 +1217,6 @@ mod tests {
use settings::Settings;
use sqlx::types::time::OffsetDateTime;
use std::{
- cell::Cell,
env,
ops::Deref,
path::{Path, PathBuf},
@@ -1118,7 +1228,6 @@ mod tests {
time::Duration,
};
use theme::ThemeRegistry;
- use util::TryFutureExt;
use workspace::{Item, SplitDirection, ToggleFollow, Workspace, WorkspaceParams};
#[cfg(test)]
@@ -4975,11 +5084,17 @@ mod tests {
}
#[gpui::test(iterations = 100)]
- async fn test_random_collaboration(cx: &mut TestAppContext, rng: StdRng) {
+ async fn test_random_collaboration(
+ cx: &mut TestAppContext,
+ deterministic: Arc<Deterministic>,
+ rng: StdRng,
+ ) {
cx.foreground().forbid_parking();
let max_peers = env::var("MAX_PEERS")
.map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
.unwrap_or(5);
+ assert!(max_peers <= 5);
+
let max_operations = env::var("OPERATIONS")
.map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
.unwrap_or(10);
@@ -4993,23 +5108,23 @@ mod tests {
fs.insert_tree(
"/_collab",
json!({
- ".zed.toml": r#"collaborators = ["guest-1", "guest-2", "guest-3", "guest-4", "guest-5"]"#
+ ".zed.toml": r#"collaborators = ["guest-1", "guest-2", "guest-3", "guest-4"]"#
}),
)
.await;
- let operations = Rc::new(Cell::new(0));
let mut server = TestServer::start(cx.foreground(), cx.background()).await;
let mut clients = Vec::new();
let mut user_ids = Vec::new();
+ let mut op_start_signals = Vec::new();
let files = Arc::new(Mutex::new(Vec::new()));
let mut next_entity_id = 100000;
let mut host_cx = TestAppContext::new(
cx.foreground_platform(),
cx.platform(),
- cx.foreground(),
- cx.background(),
+ deterministic.build_foreground(next_entity_id),
+ deterministic.build_background(),
cx.font_cache(),
cx.leak_detector(),
next_entity_id,
@@ -5169,77 +5284,53 @@ mod tests {
});
host_language_registry.add(Arc::new(language));
- let host_disconnected = Rc::new(AtomicBool::new(false));
+ let op_start_signal = futures::channel::mpsc::unbounded();
user_ids.push(host.current_user_id(&host_cx));
- clients.push(cx.foreground().spawn(host.simulate_host(
+ op_start_signals.push(op_start_signal.0);
+ clients.push(host_cx.foreground().spawn(host.simulate_host(
host_project,
files,
- operations.clone(),
- max_operations,
+ op_start_signal.1,
rng.clone(),
host_cx,
)));
- while operations.get() < max_operations {
- cx.background().simulate_random_delay().await;
- if clients.len() >= max_peers {
- break;
- } else if rng.lock().gen_bool(0.05) {
- operations.set(operations.get() + 1);
-
- let guest_id = clients.len();
- log::info!("Adding guest {}", guest_id);
- next_entity_id += 100000;
- let mut guest_cx = TestAppContext::new(
- cx.foreground_platform(),
- cx.platform(),
- cx.foreground(),
- cx.background(),
- cx.font_cache(),
- cx.leak_detector(),
- next_entity_id,
- );
- let guest = server
- .create_client(&mut guest_cx, &format!("guest-{}", guest_id))
- .await;
- let guest_project = Project::remote(
- host_project_id,
- guest.client.clone(),
- guest.user_store.clone(),
- guest_lang_registry.clone(),
- FakeFs::new(cx.background()),
- &mut guest_cx.to_async(),
- )
- .await
- .unwrap();
- user_ids.push(guest.current_user_id(&guest_cx));
- clients.push(cx.foreground().spawn(guest.simulate_guest(
- guest_id,
- guest_project,
- operations.clone(),
- max_operations,
- rng.clone(),
- host_disconnected.clone(),
- guest_cx,
- )));
-
- log::info!("Guest {} added", guest_id);
- } else if rng.lock().gen_bool(0.05) {
- host_disconnected.store(true, SeqCst);
+ let disconnect_host_at = if rng.lock().gen_bool(0.2) {
+ rng.lock().gen_range(0..max_operations)
+ } else {
+ max_operations
+ };
+ let mut available_guests = vec![
+ "guest-1".to_string(),
+ "guest-2".to_string(),
+ "guest-3".to_string(),
+ "guest-4".to_string(),
+ ];
+ let mut operations = 0;
+ while operations < max_operations {
+ if operations == disconnect_host_at {
server.disconnect_client(user_ids[0]);
cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+ drop(op_start_signals);
let mut clients = futures::future::join_all(clients).await;
cx.foreground().run_until_parked();
- let (host, mut host_cx) = clients.remove(0);
+ let (host, mut host_cx, host_err) = clients.remove(0);
+ if let Some(host_err) = host_err {
+ log::error!("host error - {}", host_err);
+ }
host.project
.as_ref()
.unwrap()
.read_with(&host_cx, |project, _| assert!(!project.is_shared()));
- for (guest, mut guest_cx) in clients {
+ for (guest, mut guest_cx, guest_err) in clients {
+ if let Some(guest_err) = guest_err {
+ log::error!("{} error - {}", guest.username, guest_err);
+ }
let contacts = server
.store
.read()
+ .await
.contacts_for_user(guest.current_user_id(&guest_cx));
assert!(!contacts
.iter()
@@ -5256,12 +5347,113 @@ mod tests {
return;
}
+
+ let distribution = rng.lock().gen_range(0..100);
+ match distribution {
+ 0..=19 if !available_guests.is_empty() => {
+ let guest_ix = rng.lock().gen_range(0..available_guests.len());
+ let guest_username = available_guests.remove(guest_ix);
+ log::info!("Adding new connection for {}", guest_username);
+ next_entity_id += 100000;
+ let mut guest_cx = TestAppContext::new(
+ cx.foreground_platform(),
+ cx.platform(),
+ deterministic.build_foreground(next_entity_id),
+ deterministic.build_background(),
+ cx.font_cache(),
+ cx.leak_detector(),
+ next_entity_id,
+ );
+ let guest = server.create_client(&mut guest_cx, &guest_username).await;
+ let guest_project = Project::remote(
+ host_project_id,
+ guest.client.clone(),
+ guest.user_store.clone(),
+ guest_lang_registry.clone(),
+ FakeFs::new(cx.background()),
+ &mut guest_cx.to_async(),
+ )
+ .await
+ .unwrap();
+ let op_start_signal = futures::channel::mpsc::unbounded();
+ user_ids.push(guest.current_user_id(&guest_cx));
+ op_start_signals.push(op_start_signal.0);
+ clients.push(guest_cx.foreground().spawn(guest.simulate_guest(
+ guest_username.clone(),
+ guest_project,
+ op_start_signal.1,
+ rng.clone(),
+ guest_cx,
+ )));
+
+ log::info!("Added connection for {}", guest_username);
+ operations += 1;
+ }
+ 20..=29 if clients.len() > 1 => {
+ log::info!("Removing guest");
+ let guest_ix = rng.lock().gen_range(1..clients.len());
+ let removed_guest_id = user_ids.remove(guest_ix);
+ let guest = clients.remove(guest_ix);
+ op_start_signals.remove(guest_ix);
+ server.disconnect_client(removed_guest_id);
+ cx.foreground().advance_clock(RECEIVE_TIMEOUT);
+ let (guest, mut guest_cx, guest_err) = guest.await;
+ if let Some(guest_err) = guest_err {
+ log::error!("{} error - {}", guest.username, guest_err);
+ }
+ guest
+ .project
+ .as_ref()
+ .unwrap()
+ .read_with(&guest_cx, |project, _| assert!(project.is_read_only()));
+ for user_id in &user_ids {
+ for contact in server.store.read().await.contacts_for_user(*user_id) {
+ assert_ne!(
+ contact.user_id, removed_guest_id.0 as u64,
+ "removed guest is still a contact of another peer"
+ );
+ for project in contact.projects {
+ for project_guest_id in project.guests {
+ assert_ne!(
+ project_guest_id, removed_guest_id.0 as u64,
+ "removed guest appears as still participating on a project"
+ );
+ }
+ }
+ }
+ }
+
+ log::info!("{} removed", guest.username);
+ available_guests.push(guest.username.clone());
+ guest_cx.update(|_| drop(guest));
+
+ operations += 1;
+ }
+ _ => {
+ while operations < max_operations && rng.lock().gen_bool(0.7) {
+ op_start_signals
+ .choose(&mut *rng.lock())
+ .unwrap()
+ .unbounded_send(())
+ .unwrap();
+ operations += 1;
+ }
+
+ if rng.lock().gen_bool(0.8) {
+ cx.foreground().run_until_parked();
+ }
+ }
+ }
}
+ drop(op_start_signals);
let mut clients = futures::future::join_all(clients).await;
cx.foreground().run_until_parked();
- let (host_client, mut host_cx) = clients.remove(0);
+ let (host_client, mut host_cx, host_err) = clients.remove(0);
+ if let Some(host_err) = host_err {
+ panic!("host error - {}", host_err);
+ }
let host_project = host_client.project.as_ref().unwrap();
let host_worktree_snapshots = host_project.read_with(&host_cx, |project, cx| {
project