diff --git a/crates/chat_panel/src/chat_panel.rs b/crates/chat_panel/src/chat_panel.rs index b155d9fc3260225cdddae2529a90ce34b16d5f67..800084ff1db1d2c9670368339e9ea4da63ee1c5d 100644 --- a/crates/chat_panel/src/chat_panel.rs +++ b/crates/chat_panel/src/chat_panel.rs @@ -325,17 +325,21 @@ impl ChatPanel { enum SignInPromptLabel {} Align::new( - MouseEventHandler::new::(0, cx, |mouse_state, _| { - Label::new( - "Sign in to use chat".to_string(), - if mouse_state.hovered { - theme.chat_panel.hovered_sign_in_prompt.clone() - } else { - theme.chat_panel.sign_in_prompt.clone() - }, - ) - .boxed() - }) + MouseEventHandler::new::( + cx.view_id(), + cx, + |mouse_state, _| { + Label::new( + "Sign in to use chat".to_string(), + if mouse_state.hovered { + theme.chat_panel.hovered_sign_in_prompt.clone() + } else { + theme.chat_panel.sign_in_prompt.clone() + }, + ) + .boxed() + }, + ) .with_cursor_style(CursorStyle::PointingHand) .on_click(move |cx| { let rpc = rpc.clone(); diff --git a/crates/client/src/channel.rs b/crates/client/src/channel.rs index ab65b4d22830c5ea4d815d6db7f352574c59f013..1b00d4daf697d73f7bda0dea140dc399711a06f1 100644 --- a/crates/client/src/channel.rs +++ b/crates/client/src/channel.rs @@ -4,6 +4,7 @@ use super::{ Client, Status, Subscription, TypedEnvelope, }; use anyhow::{anyhow, Context, Result}; +use futures::lock::Mutex; use gpui::{ AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle, }; @@ -40,6 +41,7 @@ pub struct Channel { next_pending_message_id: usize, user_store: ModelHandle, rpc: Arc, + outgoing_messages_lock: Arc>, rng: StdRng, _subscription: Subscription, } @@ -178,14 +180,17 @@ impl Entity for Channel { } impl Channel { + pub fn init(rpc: &Arc) { + rpc.add_entity_message_handler(Self::handle_message_sent); + } + pub fn new( details: ChannelDetails, user_store: ModelHandle, rpc: Arc, cx: &mut ModelContext, ) -> Self { - let _subscription = - rpc.add_entity_message_handler(details.id, cx, Self::handle_message_sent); + let _subscription = rpc.add_model_for_remote_entity(details.id, cx); { let user_store = user_store.clone(); @@ -214,6 +219,7 @@ impl Channel { details, user_store, rpc, + outgoing_messages_lock: Default::default(), messages: Default::default(), loaded_all_messages: false, next_pending_message_id: 0, @@ -259,13 +265,16 @@ impl Channel { ); let user_store = self.user_store.clone(); let rpc = self.rpc.clone(); + let outgoing_messages_lock = self.outgoing_messages_lock.clone(); Ok(cx.spawn(|this, mut cx| async move { + let outgoing_message_guard = outgoing_messages_lock.lock().await; let request = rpc.request(proto::SendChannelMessage { channel_id, body, nonce: Some(nonce.into()), }); let response = request.await?; + drop(outgoing_message_guard); let message = ChannelMessage::from_proto( response.message.ok_or_else(|| anyhow!("invalid message"))?, &user_store, @@ -589,10 +598,14 @@ mod tests { #[gpui::test] async fn test_channel_messages(mut cx: TestAppContext) { + cx.foreground().forbid_parking(); + let user_id = 5; let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) }); let mut client = Client::new(http_client.clone()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; + + Channel::init(&client); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx)); let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx)); diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 702360f7787a238e119df1da50f071a49fefc6be..157ecf3af67ae0ea8d5e86383b502cac376345a0 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -12,7 +12,10 @@ use async_tungstenite::tungstenite::{ http::{Request, StatusCode}, }; use futures::{future::LocalBoxFuture, FutureExt, StreamExt}; -use gpui::{action, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task}; +use gpui::{ + action, AnyModelHandle, AnyWeakModelHandle, AsyncAppContext, Entity, ModelContext, ModelHandle, + MutableAppContext, Task, +}; use http::HttpClient; use lazy_static::lazy_static; use parking_lot::RwLock; @@ -20,7 +23,7 @@ use postage::watch; use rand::prelude::*; use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}; use std::{ - any::{type_name, TypeId}, + any::TypeId, collections::HashMap, convert::TryFrom, fmt::Write as _, @@ -124,19 +127,28 @@ pub enum Status { ReconnectionError { next_reconnection: Instant }, } -type ModelHandler = Box< - dyn Send - + Sync - + FnMut(Box, &AsyncAppContext) -> LocalBoxFuture<'static, Result<()>>, ->; - struct ClientState { credentials: Option, status: (watch::Sender, watch::Receiver), entity_id_extractors: HashMap u64>>, - model_handlers: HashMap<(TypeId, Option), Option>, _maintain_connection: Option>, heartbeat_interval: Duration, + + models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>, + models_by_message_type: HashMap, + model_types_by_message_type: HashMap, + message_handlers: HashMap< + TypeId, + Arc< + dyn Send + + Sync + + Fn( + AnyModelHandle, + Box, + AsyncAppContext, + ) -> LocalBoxFuture<'static, Result<()>>, + >, + >, } #[derive(Clone, Debug)] @@ -151,23 +163,43 @@ impl Default for ClientState { credentials: None, status: watch::channel_with(Status::SignedOut), entity_id_extractors: Default::default(), - model_handlers: Default::default(), _maintain_connection: None, heartbeat_interval: Duration::from_secs(5), + models_by_message_type: Default::default(), + models_by_entity_type_and_remote_id: Default::default(), + model_types_by_message_type: Default::default(), + message_handlers: Default::default(), } } } -pub struct Subscription { - client: Weak, - id: (TypeId, Option), +pub enum Subscription { + Entity { + client: Weak, + id: (TypeId, u64), + }, + Message { + client: Weak, + id: TypeId, + }, } impl Drop for Subscription { fn drop(&mut self) { - if let Some(client) = self.client.upgrade() { - let mut state = client.state.write(); - let _ = state.model_handlers.remove(&self.id).unwrap(); + match self { + Subscription::Entity { client, id } => { + if let Some(client) = client.upgrade() { + let mut state = client.state.write(); + let _ = state.models_by_entity_type_and_remote_id.remove(id); + } + } + Subscription::Message { client, id } => { + if let Some(client) = client.upgrade() { + let mut state = client.state.write(); + let _ = state.model_types_by_message_type.remove(id); + let _ = state.message_handlers.remove(id); + } + } } } } @@ -188,6 +220,10 @@ impl Client { }) } + pub fn id(&self) -> usize { + self.id + } + #[cfg(any(test, feature = "test-support"))] pub fn override_authenticate(&mut self, authenticate: F) -> &mut Self where @@ -266,125 +302,118 @@ impl Client { } } - pub fn add_message_handler( + pub fn add_model_for_remote_entity( + self: &Arc, + remote_id: u64, + cx: &mut ModelContext, + ) -> Subscription { + let handle = AnyModelHandle::from(cx.handle()); + let mut state = self.state.write(); + let id = (TypeId::of::(), remote_id); + state + .models_by_entity_type_and_remote_id + .insert(id, handle.downgrade()); + Subscription::Entity { + client: Arc::downgrade(self), + id, + } + } + + pub fn add_message_handler( self: &Arc, - cx: &mut ModelContext, - mut handler: F, + model: ModelHandle, + handler: H, ) -> Subscription where - T: EnvelopedMessage, - M: Entity, - F: 'static + M: EnvelopedMessage, + E: Entity, + H: 'static + Send + Sync - + FnMut(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> Fut, - Fut: 'static + Future>, + + Fn(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future>, { - let subscription_id = (TypeId::of::(), None); + let message_type_id = TypeId::of::(); + let client = self.clone(); let mut state = self.state.write(); - let model = cx.weak_handle(); - let prev_handler = state.model_handlers.insert( - subscription_id, - Some(Box::new(move |envelope, cx| { - if let Some(model) = model.upgrade(cx) { - let envelope = envelope.into_any().downcast::>().unwrap(); - handler(model, *envelope, client.clone(), cx.clone()).boxed_local() - } else { - async move { - Err(anyhow!( - "received message for {:?} but model was dropped", - type_name::() - )) - } - .boxed_local() - } - })), + state + .models_by_message_type + .insert(message_type_id, model.into()); + + let prev_handler = state.message_handlers.insert( + message_type_id, + Arc::new(move |handle, envelope, cx| { + let model = handle.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + handler(model, *envelope, client.clone(), cx).boxed_local() + }), ); if prev_handler.is_some() { panic!("registered handler for the same message twice"); } - Subscription { + Subscription::Message { client: Arc::downgrade(self), - id: subscription_id, + id: message_type_id, } } - pub fn add_entity_message_handler( - self: &Arc, - remote_id: u64, - cx: &mut ModelContext, - mut handler: F, - ) -> Subscription + pub fn add_entity_message_handler(self: &Arc, handler: H) where - T: EntityMessage, - M: Entity, - F: 'static + M: EntityMessage, + E: Entity, + H: 'static + Send + Sync - + FnMut(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> Fut, - Fut: 'static + Future>, + + Fn(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future>, { - let subscription_id = (TypeId::of::(), Some(remote_id)); + let model_type_id = TypeId::of::(); + let message_type_id = TypeId::of::(); + let client = self.clone(); let mut state = self.state.write(); - let model = cx.weak_handle(); + state + .model_types_by_message_type + .insert(message_type_id, model_type_id); state .entity_id_extractors - .entry(subscription_id.0) + .entry(message_type_id) .or_insert_with(|| { Box::new(|envelope| { let envelope = envelope .as_any() - .downcast_ref::>() + .downcast_ref::>() .unwrap(); envelope.payload.remote_entity_id() }) }); - let prev_handler = state.model_handlers.insert( - subscription_id, - Some(Box::new(move |envelope, cx| { - if let Some(model) = model.upgrade(cx) { - let envelope = envelope.into_any().downcast::>().unwrap(); - handler(model, *envelope, client.clone(), cx.clone()).boxed_local() - } else { - async move { - Err(anyhow!( - "received message for {:?} but model was dropped", - type_name::() - )) - } - .boxed_local() - } - })), + + let prev_handler = state.message_handlers.insert( + message_type_id, + Arc::new(move |handle, envelope, cx| { + let model = handle.downcast::().unwrap(); + let envelope = envelope.into_any().downcast::>().unwrap(); + handler(model, *envelope, client.clone(), cx).boxed_local() + }), ); if prev_handler.is_some() { - panic!("registered a handler for the same entity twice") - } - - Subscription { - client: Arc::downgrade(self), - id: subscription_id, + panic!("registered handler for the same message twice"); } } - pub fn add_entity_request_handler( - self: &Arc, - remote_id: u64, - cx: &mut ModelContext, - mut handler: F, - ) -> Subscription + pub fn add_entity_request_handler(self: &Arc, handler: H) where - T: EntityMessage + RequestMessage, - M: Entity, - F: 'static + M: EntityMessage + RequestMessage, + E: Entity, + H: 'static + Send + Sync - + FnMut(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> Fut, - Fut: 'static + Future>, + + Fn(ModelHandle, TypedEnvelope, Arc, AsyncAppContext) -> F, + F: 'static + Future>, { - self.add_entity_message_handler(remote_id, cx, move |model, envelope, client, cx| { + self.add_entity_message_handler(move |model, envelope, client, cx| { let receipt = envelope.receipt(); let response = handler(model, envelope, client.clone(), cx); async move { @@ -500,27 +529,45 @@ impl Client { while let Some(message) = incoming.next().await { let mut state = this.state.write(); let payload_type_id = message.payload_type_id(); - let entity_id = if let Some(extract_entity_id) = - state.entity_id_extractors.get(&message.payload_type_id()) - { - Some((extract_entity_id)(message.as_ref())) + let type_name = message.payload_type_name(); + + let model = state + .models_by_message_type + .get(&payload_type_id) + .cloned() + .or_else(|| { + let model_type_id = + *state.model_types_by_message_type.get(&payload_type_id)?; + let entity_id = state + .entity_id_extractors + .get(&message.payload_type_id()) + .map(|extract_entity_id| { + (extract_entity_id)(message.as_ref()) + })?; + let model = state + .models_by_entity_type_and_remote_id + .get(&(model_type_id, entity_id))?; + if let Some(model) = model.upgrade(&cx) { + Some(model) + } else { + state + .models_by_entity_type_and_remote_id + .remove(&(model_type_id, entity_id)); + None + } + }); + + let model = if let Some(model) = model { + model } else { - None + log::info!("unhandled message {}", type_name); + continue; }; - let type_name = message.payload_type_name(); - - let handler_key = (payload_type_id, entity_id); - if let Some(handler) = state.model_handlers.get_mut(&handler_key) { - let mut handler = handler.take().unwrap(); + if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned() + { drop(state); // Avoid deadlocks if the handler interacts with rpc::Client - let future = (handler)(message, &cx); - { - let mut state = this.state.write(); - if state.model_handlers.contains_key(&handler_key) { - state.model_handlers.insert(handler_key, Some(handler)); - } - } + let future = handler(model, message, cx.clone()); let client_id = this.id; log::debug!( @@ -540,7 +587,7 @@ impl Client { } Err(error) => { log::error!( - "error handling rpc message. client_id:{}, name:{}, error:{}", + "error handling message. client_id:{}, name:{}, {}", client_id, type_name, error @@ -923,39 +970,39 @@ mod tests { let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; - let model = cx.add_model(|_| Model { subscription: None }); - let (mut done_tx1, mut done_rx1) = postage::oneshot::channel(); - let (mut done_tx2, mut done_rx2) = postage::oneshot::channel(); - let _subscription1 = model.update(&mut cx, |_, cx| { - client.add_entity_message_handler( - 1, - cx, - move |_, _: TypedEnvelope, _, _| { - postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap(); - async { Ok(()) } - }, - ) + let (done_tx1, mut done_rx1) = smol::channel::unbounded(); + let (done_tx2, mut done_rx2) = smol::channel::unbounded(); + client.add_entity_message_handler( + move |model: ModelHandle, _: TypedEnvelope, _, cx| { + match model.read_with(&cx, |model, _| model.id) { + 1 => done_tx1.try_send(()).unwrap(), + 2 => done_tx2.try_send(()).unwrap(), + _ => unreachable!(), + } + async { Ok(()) } + }, + ); + let model1 = cx.add_model(|_| Model { + id: 1, + subscription: None, }); - let _subscription2 = model.update(&mut cx, |_, cx| { - client.add_entity_message_handler( - 2, - cx, - move |_, _: TypedEnvelope, _, _| { - postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap(); - async { Ok(()) } - }, - ) + let model2 = cx.add_model(|_| Model { + id: 2, + subscription: None, + }); + let model3 = cx.add_model(|_| Model { + id: 3, + subscription: None, }); + let _subscription1 = + model1.update(&mut cx, |_, cx| client.add_model_for_remote_entity(1, cx)); + let _subscription2 = + model2.update(&mut cx, |_, cx| client.add_model_for_remote_entity(2, cx)); // Ensure dropping a subscription for the same entity type still allows receiving of // messages for other entity IDs of the same type. - let subscription3 = model.update(&mut cx, |_, cx| { - client.add_entity_message_handler( - 3, - cx, - |_, _: TypedEnvelope, _, _| async { Ok(()) }, - ) - }); + let subscription3 = + model3.update(&mut cx, |_, cx| client.add_model_for_remote_entity(3, cx)); drop(subscription3); server.send(proto::UnshareProject { project_id: 1 }); @@ -972,22 +1019,22 @@ mod tests { let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; - let model = cx.add_model(|_| Model { subscription: None }); - let (mut done_tx1, _done_rx1) = postage::oneshot::channel(); - let (mut done_tx2, mut done_rx2) = postage::oneshot::channel(); - let subscription1 = model.update(&mut cx, |_, cx| { - client.add_message_handler(cx, move |_, _: TypedEnvelope, _, _| { - postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap(); + let model = cx.add_model(|_| Model::default()); + let (done_tx1, _done_rx1) = smol::channel::unbounded(); + let (done_tx2, mut done_rx2) = smol::channel::unbounded(); + let subscription1 = client.add_message_handler( + model.clone(), + move |_, _: TypedEnvelope, _, _| { + done_tx1.try_send(()).unwrap(); async { Ok(()) } - }) - }); + }, + ); drop(subscription1); - let _subscription2 = model.update(&mut cx, |_, cx| { - client.add_message_handler(cx, move |_, _: TypedEnvelope, _, _| { - postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap(); + let _subscription2 = + client.add_message_handler(model, move |_, _: TypedEnvelope, _, _| { + done_tx2.try_send(()).unwrap(); async { Ok(()) } - }) - }); + }); server.send(proto::Ping {}); done_rx2.next().await.unwrap(); } @@ -1000,23 +1047,26 @@ mod tests { let mut client = Client::new(FakeHttpClient::with_404_response()); let server = FakeServer::for_client(user_id, &mut client, &cx).await; - let model = cx.add_model(|_| Model { subscription: None }); - let (mut done_tx, mut done_rx) = postage::oneshot::channel(); - model.update(&mut cx, |model, cx| { - model.subscription = Some(client.add_message_handler( - cx, - move |model, _: TypedEnvelope, _, mut cx| { - model.update(&mut cx, |model, _| model.subscription.take()); - postage::sink::Sink::try_send(&mut done_tx, ()).unwrap(); - async { Ok(()) } - }, - )); + let model = cx.add_model(|_| Model::default()); + let (done_tx, mut done_rx) = smol::channel::unbounded(); + let subscription = client.add_message_handler( + model.clone(), + move |model, _: TypedEnvelope, _, mut cx| { + model.update(&mut cx, |model, _| model.subscription.take()); + done_tx.try_send(()).unwrap(); + async { Ok(()) } + }, + ); + model.update(&mut cx, |model, _| { + model.subscription = Some(subscription); }); server.send(proto::Ping {}); done_rx.next().await.unwrap(); } + #[derive(Default)] struct Model { + id: usize, subscription: Option, } diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index c8aca791923cc57f97c0e05e2c42171bbfe66b3e..697bf3860c224a51cac3096f1358d9780948d6fb 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,25 +1,28 @@ -use super::Client; -use super::*; -use crate::http::{HttpClient, Request, Response, ServerResponse}; +use crate::{ + http::{HttpClient, Request, Response, ServerResponse}, + Client, Connection, Credentials, EstablishConnectionError, UserStore, +}; +use anyhow::{anyhow, Result}; use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt}; -use gpui::{ModelHandle, TestAppContext}; +use gpui::{executor, ModelHandle, TestAppContext}; use parking_lot::Mutex; use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; -use std::fmt; -use std::sync::atomic::Ordering::SeqCst; -use std::sync::{ - atomic::{AtomicBool, AtomicUsize}, - Arc, -}; +use std::{fmt, rc::Rc, sync::Arc}; pub struct FakeServer { peer: Arc, - incoming: Mutex>>>, - connection_id: Mutex>, - forbid_connections: AtomicBool, - auth_count: AtomicUsize, - access_token: AtomicUsize, + state: Arc>, user_id: u64, + executor: Rc, +} + +#[derive(Default)] +struct FakeServerState { + incoming: Option>>, + connection_id: Option, + forbid_connections: bool, + auth_count: usize, + access_token: usize, } impl FakeServer { @@ -27,24 +30,22 @@ impl FakeServer { client_user_id: u64, client: &mut Arc, cx: &TestAppContext, - ) -> Arc { - let server = Arc::new(Self { + ) -> Self { + let server = Self { peer: Peer::new(), - incoming: Default::default(), - connection_id: Default::default(), - forbid_connections: Default::default(), - auth_count: Default::default(), - access_token: Default::default(), + state: Default::default(), user_id: client_user_id, - }); + executor: cx.foreground(), + }; Arc::get_mut(client) .unwrap() .override_authenticate({ - let server = server.clone(); + let state = server.state.clone(); move |cx| { - server.auth_count.fetch_add(1, SeqCst); - let access_token = server.access_token.load(SeqCst).to_string(); + let mut state = state.lock(); + state.auth_count += 1; + let access_token = state.access_token.to_string(); cx.spawn(move |_| async move { Ok(Credentials { user_id: client_user_id, @@ -54,12 +55,32 @@ impl FakeServer { } }) .override_establish_connection({ - let server = server.clone(); + let peer = server.peer.clone(); + let state = server.state.clone(); move |credentials, cx| { + let peer = peer.clone(); + let state = state.clone(); let credentials = credentials.clone(); - cx.spawn({ - let server = server.clone(); - move |cx| async move { server.establish_connection(&credentials, &cx).await } + cx.spawn(move |cx| async move { + assert_eq!(credentials.user_id, client_user_id); + + if state.lock().forbid_connections { + Err(EstablishConnectionError::Other(anyhow!( + "server is forbidding connections" + )))? + } + + if credentials.access_token != state.lock().access_token.to_string() { + Err(EstablishConnectionError::Unauthorized)? + } + + let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); + let (connection_id, io, incoming) = peer.add_connection(server_conn).await; + cx.background().spawn(io).detach(); + let mut state = state.lock(); + state.connection_id = Some(connection_id); + state.incoming = Some(incoming); + Ok(client_conn) }) } }); @@ -73,49 +94,25 @@ impl FakeServer { pub fn disconnect(&self) { self.peer.disconnect(self.connection_id()); - self.connection_id.lock().take(); - self.incoming.lock().take(); - } - - async fn establish_connection( - &self, - credentials: &Credentials, - cx: &AsyncAppContext, - ) -> Result { - assert_eq!(credentials.user_id, self.user_id); - - if self.forbid_connections.load(SeqCst) { - Err(EstablishConnectionError::Other(anyhow!( - "server is forbidding connections" - )))? - } - - if credentials.access_token != self.access_token.load(SeqCst).to_string() { - Err(EstablishConnectionError::Unauthorized)? - } - - let (client_conn, server_conn, _) = Connection::in_memory(cx.background()); - let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await; - cx.background().spawn(io).detach(); - *self.incoming.lock() = Some(incoming); - *self.connection_id.lock() = Some(connection_id); - Ok(client_conn) + let mut state = self.state.lock(); + state.connection_id.take(); + state.incoming.take(); } pub fn auth_count(&self) -> usize { - self.auth_count.load(SeqCst) + self.state.lock().auth_count } pub fn roll_access_token(&self) { - self.access_token.fetch_add(1, SeqCst); + self.state.lock().access_token += 1; } pub fn forbid_connections(&self) { - self.forbid_connections.store(true, SeqCst); + self.state.lock().forbid_connections = true; } pub fn allow_connections(&self) { - self.forbid_connections.store(false, SeqCst); + self.state.lock().forbid_connections = false; } pub fn send(&self, message: T) { @@ -123,14 +120,17 @@ impl FakeServer { } pub async fn receive(&self) -> Result> { + self.executor.start_waiting(); let message = self - .incoming + .state .lock() + .incoming .as_mut() .expect("not connected") .next() .await .ok_or_else(|| anyhow!("other half hung up"))?; + self.executor.finish_waiting(); let type_name = message.payload_type_name(); Ok(*message .into_any() @@ -152,7 +152,7 @@ impl FakeServer { } fn connection_id(&self) -> ConnectionId { - self.connection_id.lock().expect("not connected") + self.state.lock().connection_id.expect("not connected") } pub async fn build_user_store( diff --git a/crates/client/src/user.rs b/crates/client/src/user.rs index 1e4f7fe4d4d5811512bc177fd60151ab31e6d0cf..c318c7f5050fbc7e80125c12289f84d5b8abb20c 100644 --- a/crates/client/src/user.rs +++ b/crates/client/src/user.rs @@ -35,6 +35,7 @@ pub struct ProjectMetadata { pub struct UserStore { users: HashMap>, + update_contacts_tx: watch::Sender>, current_user: watch::Receiver>>, contacts: Arc<[Contact]>, client: Arc, @@ -56,23 +57,19 @@ impl UserStore { cx: &mut ModelContext, ) -> Self { let (mut current_user_tx, current_user_rx) = watch::channel(); - let (mut update_contacts_tx, mut update_contacts_rx) = + let (update_contacts_tx, mut update_contacts_rx) = watch::channel::>(); - let update_contacts_subscription = client.add_message_handler( - cx, - move |_: ModelHandle, msg: TypedEnvelope, _, _| { - *update_contacts_tx.borrow_mut() = Some(msg.payload); - async move { Ok(()) } - }, - ); + let rpc_subscription = + client.add_message_handler(cx.handle(), Self::handle_update_contacts); Self { users: Default::default(), current_user: current_user_rx, contacts: Arc::from([]), client: client.clone(), + update_contacts_tx, http, _maintain_contacts: cx.spawn_weak(|this, mut cx| async move { - let _subscription = update_contacts_subscription; + let _subscription = rpc_subscription; while let Some(message) = update_contacts_rx.recv().await { if let Some((message, this)) = message.zip(this.upgrade(&cx)) { this.update(&mut cx, |this, cx| this.update_contacts(message, cx)) @@ -104,6 +101,18 @@ impl UserStore { } } + async fn handle_update_contacts( + this: ModelHandle, + msg: TypedEnvelope, + _: Arc, + mut cx: AsyncAppContext, + ) -> Result<()> { + this.update(&mut cx, |this, _| { + *this.update_contacts_tx.borrow_mut() = Some(msg.payload); + }); + Ok(()) + } + fn update_contacts( &mut self, message: proto::UpdateContacts, diff --git a/crates/clock/src/clock.rs b/crates/clock/src/clock.rs index 888889871f5330e773d832f217f6ca11b7515cb4..0fdeda0b99427150320f8d8f31e8b0ac212166a0 100644 --- a/crates/clock/src/clock.rs +++ b/crates/clock/src/clock.rs @@ -216,6 +216,16 @@ impl Global { } } +impl FromIterator for Global { + fn from_iter>(locals: T) -> Self { + let mut result = Self::new(); + for local in locals { + result.observe(local); + } + result + } +} + impl Ord for Lamport { fn cmp(&self, other: &Self) -> Ordering { // Use the replica id to break ties between concurrent events. diff --git a/crates/diagnostics/src/items.rs b/crates/diagnostics/src/items.rs index 80291cde3dc58d36dbeeb84a144f9e271ffd43e0..7949fe952c19e4982675bc492d38a609da1d8b5b 100644 --- a/crates/diagnostics/src/items.rs +++ b/crates/diagnostics/src/items.rs @@ -57,7 +57,7 @@ impl View for DiagnosticSummary { let theme = &self.settings.borrow().theme.project_diagnostics; let in_progress = self.in_progress; - MouseEventHandler::new::(0, cx, |_, _| { + MouseEventHandler::new::(cx.view_id(), cx, |_, _| { if in_progress { Label::new( "Checking... ".to_string(), diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs index f2ce168a09559ea40fc283c1f5c7171f897e4c4a..cb2e26b9f8feb73b4e6f232770da9b2a5c572f4d 100644 --- a/crates/editor/src/editor.rs +++ b/crates/editor/src/editor.rs @@ -2264,7 +2264,7 @@ impl Editor { enum Tag {} let style = (self.build_settings)(cx).style; Some( - MouseEventHandler::new::(0, cx, |_, _| { + MouseEventHandler::new::(cx.view_id(), cx, |_, _| { Svg::new("icons/zap.svg") .with_color(style.code_actions_indicator) .boxed() @@ -5447,8 +5447,8 @@ mod tests { use super::*; use language::LanguageConfig; use lsp::FakeLanguageServer; - use postage::prelude::Stream; use project::{FakeFs, ProjectPath}; + use smol::stream::StreamExt; use std::{cell::RefCell, rc::Rc, time::Instant}; use text::Point; use unindent::Unindent; @@ -7737,9 +7737,8 @@ mod tests { }), ..Default::default() }, - &cx, - ) - .await; + cx.background(), + ); let text = " one @@ -7792,7 +7791,9 @@ mod tests { ], ) .await; - editor.next_notification(&cx).await; + editor + .condition(&cx, |editor, _| editor.context_menu_visible()) + .await; let apply_additional_edits = editor.update(&mut cx, |editor, cx| { editor.move_down(&MoveDown, cx); @@ -7856,7 +7857,7 @@ mod tests { ) .await; editor - .condition(&cx, |editor, _| editor.context_menu.is_some()) + .condition(&cx, |editor, _| editor.context_menu_visible()) .await; editor.update(&mut cx, |editor, cx| { @@ -7874,7 +7875,9 @@ mod tests { ], ) .await; - editor.next_notification(&cx).await; + editor + .condition(&cx, |editor, _| editor.context_menu_visible()) + .await; let apply_additional_edits = editor.update(&mut cx, |editor, cx| { let apply_additional_edits = editor @@ -7912,7 +7915,7 @@ mod tests { ); Some(lsp::CompletionResponse::Array( completions - .into_iter() + .iter() .map(|(range, new_text)| lsp::CompletionItem { label: new_text.to_string(), text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { @@ -7927,7 +7930,7 @@ mod tests { .collect(), )) }) - .recv() + .next() .await; } @@ -7937,7 +7940,7 @@ mod tests { ) { fake.handle_request::(move |_| { lsp::CompletionItem { - additional_text_edits: edit.map(|(range, new_text)| { + additional_text_edits: edit.clone().map(|(range, new_text)| { vec![lsp::TextEdit::new( lsp::Range::new( lsp::Position::new(range.start.row, range.start.column), @@ -7949,7 +7952,7 @@ mod tests { ..Default::default() } }) - .recv() + .next() .await; } } diff --git a/crates/fuzzy/src/char_bag.rs b/crates/fuzzy/src/char_bag.rs index c9aab0cd0bab0c39d5bc6da6873a9b377f00d259..135c5a768e0d7b0d3ddc16ebeafe614efb7c2b3b 100644 --- a/crates/fuzzy/src/char_bag.rs +++ b/crates/fuzzy/src/char_bag.rs @@ -1,6 +1,6 @@ use std::iter::FromIterator; -#[derive(Copy, Clone, Debug, Default)] +#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)] pub struct CharBag(u64); impl CharBag { diff --git a/crates/gpui/src/app.rs b/crates/gpui/src/app.rs index 1ba05c7b4ac2df25e10915f84d1812db388da74c..af07b9eca551ec456828ae57da64d1db1263f15f 100644 --- a/crates/gpui/src/app.rs +++ b/crates/gpui/src/app.rs @@ -84,6 +84,8 @@ pub trait UpgradeModelHandle { &self, handle: &WeakModelHandle, ) -> Option>; + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option; } pub trait UpgradeViewHandle { @@ -474,6 +476,10 @@ impl TestAppContext { self.cx.borrow().cx.font_cache.clone() } + pub fn foreground_platform(&self) -> Rc { + self.foreground_platform.clone() + } + pub fn platform(&self) -> Arc { self.cx.borrow().cx.platform.clone() } @@ -486,6 +492,15 @@ impl TestAppContext { self.cx.borrow().background().clone() } + pub fn spawn(&self, f: F) -> Task + where + F: FnOnce(AsyncAppContext) -> Fut, + Fut: 'static + Future, + T: 'static, + { + self.cx.borrow_mut().spawn(f) + } + pub fn simulate_new_path_selection(&self, result: impl FnOnce(PathBuf) -> Option) { self.foreground_platform.simulate_new_path_selection(result); } @@ -566,7 +581,11 @@ impl UpgradeModelHandle for AsyncAppContext { &self, handle: &WeakModelHandle, ) -> Option> { - self.0.borrow_mut().upgrade_model_handle(handle) + self.0.borrow().upgrade_model_handle(handle) + } + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option { + self.0.borrow().upgrade_any_model_handle(handle) } } @@ -685,6 +704,7 @@ pub struct MutableAppContext { next_entity_id: usize, next_window_id: usize, next_subscription_id: usize, + frame_count: usize, subscriptions: Arc>>>, observations: Arc>>>, release_observations: Arc>>>, @@ -729,6 +749,7 @@ impl MutableAppContext { next_entity_id: 0, next_window_id: 0, next_subscription_id: 0, + frame_count: 0, subscriptions: Default::default(), observations: Default::default(), release_observations: Default::default(), @@ -920,6 +941,7 @@ impl MutableAppContext { window_id: usize, titlebar_height: f32, ) -> HashMap { + self.start_frame(); let view_ids = self .views .keys() @@ -943,6 +965,10 @@ impl MutableAppContext { .collect() } + pub(crate) fn start_frame(&mut self) { + self.frame_count += 1; + } + pub fn update T>(&mut self, callback: F) -> T { self.pending_flushes += 1; let result = callback(self); @@ -1397,7 +1423,12 @@ impl MutableAppContext { .element_states .entry(key) .or_insert_with(|| Box::new(T::default())); - ElementStateHandle::new(TypeId::of::(), id, &self.cx.ref_counts) + ElementStateHandle::new( + TypeId::of::(), + id, + self.frame_count, + &self.cx.ref_counts, + ) } fn remove_dropped_entities(&mut self) { @@ -1748,6 +1779,10 @@ impl UpgradeModelHandle for MutableAppContext { ) -> Option> { self.cx.upgrade_model_handle(handle) } + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option { + self.cx.upgrade_any_model_handle(handle) + } } impl UpgradeViewHandle for MutableAppContext { @@ -1872,6 +1907,19 @@ impl UpgradeModelHandle for AppContext { None } } + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option { + if self.models.contains_key(&handle.model_id) { + self.ref_counts.lock().inc_model(handle.model_id); + Some(AnyModelHandle { + model_id: handle.model_id, + model_type: handle.model_type, + ref_counts: self.ref_counts.clone(), + }) + } else { + None + } + } } impl UpgradeViewHandle for AppContext { @@ -2264,6 +2312,10 @@ impl UpgradeModelHandle for ModelContext<'_, M> { ) -> Option> { self.cx.upgrade_model_handle(handle) } + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option { + self.cx.upgrade_any_model_handle(handle) + } } impl Deref for ModelContext<'_, M> { @@ -2594,6 +2646,10 @@ impl UpgradeModelHandle for ViewContext<'_, V> { ) -> Option> { self.cx.upgrade_model_handle(handle) } + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option { + self.cx.upgrade_any_model_handle(handle) + } } impl UpgradeViewHandle for ViewContext<'_, V> { @@ -3274,6 +3330,13 @@ impl AnyModelHandle { } } + pub fn downgrade(&self) -> AnyWeakModelHandle { + AnyWeakModelHandle { + model_id: self.model_id, + model_type: self.model_type, + } + } + pub fn is(&self) -> bool { self.model_type == TypeId::of::() } @@ -3290,12 +3353,34 @@ impl From> for AnyModelHandle { } } +impl Clone for AnyModelHandle { + fn clone(&self) -> Self { + self.ref_counts.lock().inc_model(self.model_id); + Self { + model_id: self.model_id, + model_type: self.model_type, + ref_counts: self.ref_counts.clone(), + } + } +} + impl Drop for AnyModelHandle { fn drop(&mut self) { self.ref_counts.lock().dec_model(self.model_id); } } +pub struct AnyWeakModelHandle { + model_id: usize, + model_type: TypeId, +} + +impl AnyWeakModelHandle { + pub fn upgrade(&self, cx: &impl UpgradeModelHandle) -> Option { + cx.upgrade_any_model_handle(self) + } +} + pub struct WeakViewHandle { window_id: usize, view_id: usize, @@ -3368,8 +3453,15 @@ pub struct ElementStateHandle { } impl ElementStateHandle { - fn new(tag_type_id: TypeId, id: ElementStateId, ref_counts: &Arc>) -> Self { - ref_counts.lock().inc_element_state(tag_type_id, id); + fn new( + tag_type_id: TypeId, + id: ElementStateId, + frame_id: usize, + ref_counts: &Arc>, + ) -> Self { + ref_counts + .lock() + .inc_element_state(tag_type_id, id, frame_id); Self { value_type: PhantomData, tag_type_id, @@ -3508,12 +3600,17 @@ impl Drop for Subscription { #[derive(Default)] struct RefCounts { entity_counts: HashMap, - element_state_counts: HashMap<(TypeId, ElementStateId), usize>, + element_state_counts: HashMap<(TypeId, ElementStateId), ElementStateRefCount>, dropped_models: HashSet, dropped_views: HashSet<(usize, usize)>, dropped_element_states: HashSet<(TypeId, ElementStateId)>, } +struct ElementStateRefCount { + ref_count: usize, + frame_id: usize, +} + impl RefCounts { fn inc_model(&mut self, model_id: usize) { match self.entity_counts.entry(model_id) { @@ -3537,11 +3634,21 @@ impl RefCounts { } } - fn inc_element_state(&mut self, tag_type_id: TypeId, id: ElementStateId) { + fn inc_element_state(&mut self, tag_type_id: TypeId, id: ElementStateId, frame_id: usize) { match self.element_state_counts.entry((tag_type_id, id)) { - Entry::Occupied(mut entry) => *entry.get_mut() += 1, + Entry::Occupied(mut entry) => { + let entry = entry.get_mut(); + if entry.frame_id == frame_id || entry.ref_count >= 2 { + panic!("used the same element state more than once in the same frame"); + } + entry.ref_count += 1; + entry.frame_id = frame_id; + } Entry::Vacant(entry) => { - entry.insert(1); + entry.insert(ElementStateRefCount { + ref_count: 1, + frame_id, + }); self.dropped_element_states.remove(&(tag_type_id, id)); } } @@ -3567,9 +3674,9 @@ impl RefCounts { fn dec_element_state(&mut self, tag_type_id: TypeId, id: ElementStateId) { let key = (tag_type_id, id); - let count = self.element_state_counts.get_mut(&key).unwrap(); - *count -= 1; - if *count == 0 { + let entry = self.element_state_counts.get_mut(&key).unwrap(); + entry.ref_count -= 1; + if entry.ref_count == 0 { self.element_state_counts.remove(&key); self.dropped_element_states.insert(key); } diff --git a/crates/gpui/src/elements/uniform_list.rs b/crates/gpui/src/elements/uniform_list.rs index 9248a8d146e07a4b85f627644fa696dd7a616a9a..4fbb9ca420c27520b5032fd659d6eaac3ab77019 100644 --- a/crates/gpui/src/elements/uniform_list.rs +++ b/crates/gpui/src/elements/uniform_list.rs @@ -162,7 +162,6 @@ where "UniformList does not support being rendered with an unconstrained height" ); } - let mut items = Vec::new(); if self.item_count == 0 { return ( @@ -170,22 +169,27 @@ where LayoutState { item_height: 0., scroll_max: 0., - items, + items: Default::default(), }, ); } + let mut items = Vec::new(); let mut size = constraint.max; let mut item_size; - if let Some(sample_item_ix) = self.get_width_from_item { - (self.append_items)(sample_item_ix..sample_item_ix + 1, &mut items, cx); - let sample_item = items.get_mut(0).unwrap(); + let sample_item_ix; + let mut sample_item; + if let Some(sample_ix) = self.get_width_from_item { + (self.append_items)(sample_ix..sample_ix + 1, &mut items, cx); + sample_item_ix = sample_ix; + sample_item = items.pop().unwrap(); item_size = sample_item.layout(constraint, cx); size.set_x(item_size.x()); } else { (self.append_items)(0..1, &mut items, cx); - let first_item = items.first_mut().unwrap(); - item_size = first_item.layout( + sample_item_ix = 0; + sample_item = items.pop().unwrap(); + item_size = sample_item.layout( SizeConstraint::new( vec2f(constraint.max.x(), 0.0), vec2f(constraint.max.x(), f32::INFINITY), @@ -219,8 +223,21 @@ where self.item_count, start + (size.y() / item_height).ceil() as usize + 1, ); - items.clear(); - (self.append_items)(start..end, &mut items, cx); + + if (start..end).contains(&sample_item_ix) { + if sample_item_ix > start { + (self.append_items)(start..sample_item_ix, &mut items, cx); + } + + items.push(sample_item); + + if sample_item_ix < end { + (self.append_items)(sample_item_ix + 1..end, &mut items, cx); + } + } else { + (self.append_items)(start..end, &mut items, cx); + } + for item in &mut items { let item_size = item.layout(item_constraint, cx); if item_size.x() > size.x() { diff --git a/crates/gpui/src/executor.rs b/crates/gpui/src/executor.rs index e9e71e4b72ac348e7adf1db8d0ecfe4db732fdce..24cc60b996addd6fe20f6ad5df35cf7cdcc2bcbc 100644 --- a/crates/gpui/src/executor.rs +++ b/crates/gpui/src/executor.rs @@ -370,6 +370,13 @@ impl Foreground { *any_value.downcast().unwrap() } + pub fn run_until_parked(&self) { + match self { + Self::Deterministic { executor, .. } => executor.run_until_parked(), + _ => panic!("this method can only be called on a deterministic executor"), + } + } + pub fn parking_forbidden(&self) -> bool { match self { Self::Deterministic { executor, .. } => executor.state.lock().forbid_parking, diff --git a/crates/gpui/src/presenter.rs b/crates/gpui/src/presenter.rs index 1b5adbc994db269bfb30af77196d334ba35362bd..f49081ae2dd944728c957fffba374fdca8d2c9e5 100644 --- a/crates/gpui/src/presenter.rs +++ b/crates/gpui/src/presenter.rs @@ -6,9 +6,9 @@ use crate::{ json::{self, ToJson}, platform::Event, text_layout::TextLayoutCache, - Action, AnyAction, AnyViewHandle, AssetCache, ElementBox, Entity, FontSystem, ModelHandle, - ReadModel, ReadView, Scene, UpgradeModelHandle, UpgradeViewHandle, View, ViewHandle, - WeakModelHandle, WeakViewHandle, + Action, AnyAction, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AssetCache, ElementBox, + Entity, FontSystem, ModelHandle, ReadModel, ReadView, Scene, UpgradeModelHandle, + UpgradeViewHandle, View, ViewHandle, WeakModelHandle, WeakViewHandle, }; use pathfinder_geometry::vector::{vec2f, Vector2F}; use serde_json::json; @@ -62,6 +62,7 @@ impl Presenter { } pub fn invalidate(&mut self, mut invalidation: WindowInvalidation, cx: &mut MutableAppContext) { + cx.start_frame(); for view_id in invalidation.removed { invalidation.updated.remove(&view_id); self.rendered_views.remove(&view_id); @@ -81,6 +82,7 @@ impl Presenter { invalidation: Option, cx: &mut MutableAppContext, ) { + cx.start_frame(); if let Some(invalidation) = invalidation { for view_id in invalidation.removed { self.rendered_views.remove(&view_id); @@ -278,6 +280,10 @@ impl<'a> UpgradeModelHandle for LayoutContext<'a> { ) -> Option> { self.app.upgrade_model_handle(handle) } + + fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option { + self.app.upgrade_any_model_handle(handle) + } } impl<'a> UpgradeViewHandle for LayoutContext<'a> { diff --git a/crates/gpui/src/test.rs b/crates/gpui/src/test.rs index af6430d36c7901dbea1a273996cd1f6536ef365b..ac5ca0e86682080506e2e91f02f4fb9816932ca7 100644 --- a/crates/gpui/src/test.rs +++ b/crates/gpui/src/test.rs @@ -33,6 +33,7 @@ pub fn run_test( Rc, Arc, u64, + bool, )), ) { let is_randomized = num_iterations > 1; @@ -56,10 +57,8 @@ pub fn run_test( let font_cache = Arc::new(FontCache::new(font_system)); loop { - let seed = atomic_seed.load(SeqCst); - if seed >= starting_seed + num_iterations { - break; - } + let seed = atomic_seed.fetch_add(1, SeqCst); + let is_last_iteration = seed + 1 >= starting_seed + num_iterations; if is_randomized { dbg!(seed); @@ -74,9 +73,19 @@ pub fn run_test( font_cache.clone(), 0, ); - cx.update(|cx| test_fn(cx, foreground_platform.clone(), deterministic, seed)); - - atomic_seed.fetch_add(1, SeqCst); + cx.update(|cx| { + test_fn( + cx, + foreground_platform.clone(), + deterministic, + seed, + is_last_iteration, + ) + }); + + if is_last_iteration { + break; + } } }); @@ -90,7 +99,7 @@ pub fn run_test( println!("retrying: attempt {}", retries); } else { if is_randomized { - eprintln!("failing seed: {}", atomic_seed.load(SeqCst)); + eprintln!("failing seed: {}", atomic_seed.load(SeqCst) - 1); } panic::resume_unwind(error); } diff --git a/crates/gpui_macros/src/gpui_macros.rs b/crates/gpui_macros/src/gpui_macros.rs index 21d978d9fb3e56a6a212ee8d5123cf095266f452..885cc8311a42fd788b69b26c9e82400de5a51546 100644 --- a/crates/gpui_macros/src/gpui_macros.rs +++ b/crates/gpui_macros/src/gpui_macros.rs @@ -85,7 +85,10 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { )); } Some("StdRng") => { - inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed))); + inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); + } + Some("bool") => { + inner_fn_args.extend(quote!(is_last_iteration,)); } _ => { return TokenStream::from( @@ -115,7 +118,9 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { #num_iterations as u64, #starting_seed as u64, #max_retries, - &mut |cx, foreground_platform, deterministic, seed| cx.foreground().run(#inner_fn_name(#inner_fn_args)) + &mut |cx, foreground_platform, deterministic, seed, is_last_iteration| { + cx.foreground().run(#inner_fn_name(#inner_fn_args)) + } ); } } @@ -125,8 +130,14 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { if let FnArg::Typed(arg) = arg { if let Type::Path(ty) = &*arg.ty { let last_segment = ty.path.segments.last(); - if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() { - inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); + match last_segment.map(|s| s.ident.to_string()).as_deref() { + Some("StdRng") => { + inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),)); + } + Some("bool") => { + inner_fn_args.extend(quote!(is_last_iteration,)); + } + _ => {} } } else { inner_fn_args.extend(quote!(cx,)); @@ -147,7 +158,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream { #num_iterations as u64, #starting_seed as u64, #max_retries, - &mut |cx, _, _, seed| #inner_fn_name(#inner_fn_args) + &mut |cx, _, _, seed, is_last_iteration| #inner_fn_name(#inner_fn_args) ); } } diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index f19ff21081af8b2fb9773e8ea3ae6f6496b7c155..b4543b02b02d98cfb06ecf75bbc66391b0447d47 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1283,6 +1283,10 @@ impl Buffer { self.text.wait_for_edits(edit_ids) } + pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future { + self.text.wait_for_version(version) + } + pub fn set_active_selections( &mut self, selections: Arc<[Selection]>, diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index bf24118aff1ddbacb87513408b37f79a10e205e9..73de5af12c3db0b2c4d541700225e20d4f9b8e6e 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -17,6 +17,9 @@ use std::{cell::RefCell, ops::Range, path::Path, str, sync::Arc}; use theme::SyntaxTheme; use tree_sitter::{self, Query}; +#[cfg(any(test, feature = "test-support"))] +use futures::channel::mpsc; + pub use buffer::Operation; pub use buffer::*; pub use diagnostic_set::DiagnosticEntry; @@ -79,7 +82,14 @@ pub struct LanguageServerConfig { pub disk_based_diagnostics_progress_token: Option, #[cfg(any(test, feature = "test-support"))] #[serde(skip)] - pub fake_server: Option<(Arc, Arc)>, + fake_config: Option, +} + +#[cfg(any(test, feature = "test-support"))] +struct FakeLanguageServerConfig { + servers_tx: mpsc::UnboundedSender, + capabilities: lsp::ServerCapabilities, + initializer: Option>, } #[derive(Clone, Debug, Deserialize)] @@ -224,8 +234,27 @@ impl Language { ) -> Result>> { if let Some(config) = &self.config.language_server { #[cfg(any(test, feature = "test-support"))] - if let Some((server, started)) = &config.fake_server { - started.store(true, std::sync::atomic::Ordering::SeqCst); + if let Some(fake_config) = &config.fake_config { + use postage::prelude::Stream; + + let (server, mut fake_server) = lsp::LanguageServer::fake_with_capabilities( + fake_config.capabilities.clone(), + cx.background().clone(), + ); + + if let Some(initalizer) = &fake_config.initializer { + initalizer(&mut fake_server); + } + + let servers_tx = fake_config.servers_tx.clone(); + let mut initialized = server.capabilities(); + cx.background() + .spawn(async move { + while initialized.recv().await.is_none() {} + servers_tx.unbounded_send(fake_server).ok(); + }) + .detach(); + return Ok(Some(server.clone())); } @@ -357,27 +386,32 @@ impl CompletionLabel { #[cfg(any(test, feature = "test-support"))] impl LanguageServerConfig { - pub async fn fake(cx: &gpui::TestAppContext) -> (Self, lsp::FakeLanguageServer) { - Self::fake_with_capabilities(Default::default(), cx).await - } - - pub async fn fake_with_capabilities( - capabilites: lsp::ServerCapabilities, - cx: &gpui::TestAppContext, - ) -> (Self, lsp::FakeLanguageServer) { - let (server, fake) = lsp::LanguageServer::fake_with_capabilities(capabilites, cx).await; - fake.started - .store(false, std::sync::atomic::Ordering::SeqCst); - let started = fake.started.clone(); + pub fn fake() -> (Self, mpsc::UnboundedReceiver) { + let (servers_tx, servers_rx) = mpsc::unbounded(); ( Self { - fake_server: Some((server, started)), + fake_config: Some(FakeLanguageServerConfig { + servers_tx, + capabilities: Default::default(), + initializer: None, + }), disk_based_diagnostics_progress_token: Some("fakeServer/check".to_string()), ..Default::default() }, - fake, + servers_rx, ) } + + pub fn set_fake_capabilities(&mut self, capabilities: lsp::ServerCapabilities) { + self.fake_config.as_mut().unwrap().capabilities = capabilities; + } + + pub fn set_fake_initializer( + &mut self, + initializer: impl 'static + Send + Sync + Fn(&mut lsp::FakeLanguageServer), + ) { + self.fake_config.as_mut().unwrap().initializer = Some(Box::new(initializer)); + } } impl ToLspPosition for PointUtf16 { diff --git a/crates/language/src/tests.rs b/crates/language/src/tests.rs index c7ea90714cdb474bd42018f6a249e5c8029c47c4..0ae1fbe7074ce341bd4e9ea45b4336096b572df5 100644 --- a/crates/language/src/tests.rs +++ b/crates/language/src/tests.rs @@ -557,7 +557,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut MutableAppConte #[gpui::test] async fn test_diagnostics(mut cx: gpui::TestAppContext) { - let (language_server, mut fake) = lsp::LanguageServer::fake(&cx).await; + let (language_server, mut fake) = lsp::LanguageServer::fake(cx.background()); let mut rust_lang = rust_lang(); rust_lang.config.language_server = Some(LanguageServerConfig { disk_based_diagnostic_sources: HashSet::from_iter(["disk".to_string()]), @@ -840,7 +840,7 @@ async fn test_diagnostics(mut cx: gpui::TestAppContext) { #[gpui::test] async fn test_edits_from_lsp_with_past_version(mut cx: gpui::TestAppContext) { - let (language_server, mut fake) = lsp::LanguageServer::fake(&cx).await; + let (language_server, mut fake) = lsp::LanguageServer::fake(cx.background()); let text = " fn a() { diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index 73f4fe698b8182dccdcdcaa49a9d699bd357ee94..0281e8cd8bd17080c0f8aa68a2035682050a9147 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -420,7 +420,9 @@ impl LanguageServer { anyhow!("tried to send a request to a language server that has been shut down") }) .and_then(|outbound_tx| { - outbound_tx.try_send(message)?; + outbound_tx + .try_send(message) + .context("failed to write to language server's stdin")?; Ok(()) }); async move { @@ -481,43 +483,36 @@ impl Drop for Subscription { #[cfg(any(test, feature = "test-support"))] pub struct FakeLanguageServer { - handlers: Arc< - Mutex< - HashMap< - &'static str, - Box (Vec, barrier::Sender)>, - >, - >, - >, - outgoing_tx: channel::Sender>, - incoming_rx: channel::Receiver>, - pub started: Arc, + handlers: + Arc Vec>>>>, + outgoing_tx: futures::channel::mpsc::UnboundedSender>, + incoming_rx: futures::channel::mpsc::UnboundedReceiver>, } #[cfg(any(test, feature = "test-support"))] impl LanguageServer { - pub async fn fake(cx: &gpui::TestAppContext) -> (Arc, FakeLanguageServer) { - Self::fake_with_capabilities(Default::default(), cx).await + pub fn fake(executor: Arc) -> (Arc, FakeLanguageServer) { + Self::fake_with_capabilities(Default::default(), executor) } - pub async fn fake_with_capabilities( + pub fn fake_with_capabilities( capabilities: ServerCapabilities, - cx: &gpui::TestAppContext, + executor: Arc, ) -> (Arc, FakeLanguageServer) { let (stdin_writer, stdin_reader) = async_pipe::pipe(); let (stdout_writer, stdout_reader) = async_pipe::pipe(); - let mut fake = FakeLanguageServer::new(cx, stdin_reader, stdout_writer); - fake.handle_request::(move |_| InitializeResult { - capabilities, - ..Default::default() + let mut fake = FakeLanguageServer::new(executor.clone(), stdin_reader, stdout_writer); + fake.handle_request::({ + let capabilities = capabilities.clone(); + move |_| InitializeResult { + capabilities: capabilities.clone(), + ..Default::default() + } }); let server = - Self::new_internal(stdin_writer, stdout_reader, Path::new("/"), cx.background()) - .unwrap(); - fake.receive_notification::() - .await; + Self::new_internal(stdin_writer, stdout_reader, Path::new("/"), executor).unwrap(); (server, fake) } @@ -526,63 +521,59 @@ impl LanguageServer { #[cfg(any(test, feature = "test-support"))] impl FakeLanguageServer { fn new( - cx: &gpui::TestAppContext, + background: Arc, stdin: async_pipe::PipeReader, stdout: async_pipe::PipeWriter, ) -> Self { use futures::StreamExt as _; - let (incoming_tx, incoming_rx) = channel::unbounded(); - let (outgoing_tx, mut outgoing_rx) = channel::unbounded(); + let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded(); + let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded(); let this = Self { outgoing_tx: outgoing_tx.clone(), incoming_rx, handlers: Default::default(), - started: Arc::new(std::sync::atomic::AtomicBool::new(true)), }; // Receive incoming messages let handlers = this.handlers.clone(); - cx.background() + let executor = background.clone(); + background .spawn(async move { let mut buffer = Vec::new(); let mut stdin = smol::io::BufReader::new(stdin); while Self::receive(&mut stdin, &mut buffer).await.is_ok() { - if let Ok(request) = serde_json::from_slice::(&mut buffer) { + executor.simulate_random_delay().await; + if let Ok(request) = serde_json::from_slice::(&buffer) { assert_eq!(request.jsonrpc, JSON_RPC_VERSION); - let handler = handlers.lock().remove(request.method); - if let Some(handler) = handler { - let (response, sent) = - handler(request.id, request.params.get().as_bytes()); + if let Some(handler) = handlers.lock().get_mut(request.method) { + let response = handler(request.id, request.params.get().as_bytes()); log::debug!("handled lsp request. method:{}", request.method); - outgoing_tx.send(response).await.unwrap(); - drop(sent); + outgoing_tx.unbounded_send(response)?; } else { log::debug!("unhandled lsp request. method:{}", request.method); - outgoing_tx - .send( - serde_json::to_vec(&AnyResponse { - id: request.id, - error: Some(Error { - message: "no handler".to_string(), - }), - result: None, - }) - .unwrap(), - ) - .await - .unwrap(); + outgoing_tx.unbounded_send( + serde_json::to_vec(&AnyResponse { + id: request.id, + error: Some(Error { + message: "no handler".to_string(), + }), + result: None, + }) + .unwrap(), + )?; } } else { - incoming_tx.send(buffer.clone()).await.unwrap(); + incoming_tx.unbounded_send(buffer.clone())?; } } + Ok::<_, anyhow::Error>(()) }) .detach(); // Send outgoing messages - cx.background() + background .spawn(async move { let mut stdout = smol::io::BufWriter::new(stdout); while let Some(notification) = outgoing_rx.next().await { @@ -595,16 +586,13 @@ impl FakeLanguageServer { } pub async fn notify(&mut self, params: T::Params) { - if !self.started.load(std::sync::atomic::Ordering::SeqCst) { - panic!("can't simulate an LSP notification before the server has been started"); - } let message = serde_json::to_vec(&Notification { jsonrpc: JSON_RPC_VERSION, method: T::METHOD, params, }) .unwrap(); - self.outgoing_tx.send(message).await.unwrap(); + self.outgoing_tx.unbounded_send(message).unwrap(); } pub async fn receive_notification(&mut self) -> T::Params { @@ -624,15 +612,18 @@ impl FakeLanguageServer { } } - pub fn handle_request(&mut self, handler: F) -> barrier::Receiver + pub fn handle_request( + &mut self, + mut handler: F, + ) -> futures::channel::mpsc::UnboundedReceiver<()> where T: 'static + request::Request, - F: 'static + Send + FnOnce(T::Params) -> T::Result, + F: 'static + Send + Sync + FnMut(T::Params) -> T::Result, { - let (responded_tx, responded_rx) = barrier::channel(); - let prev_handler = self.handlers.lock().insert( + let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded(); + self.handlers.lock().insert( T::METHOD, - Box::new(|id, params| { + Box::new(move |id, params| { let result = handler(serde_json::from_slice::(params).unwrap()); let result = serde_json::to_string(&result).unwrap(); let result = serde_json::from_str::<&RawValue>(&result).unwrap(); @@ -641,18 +632,20 @@ impl FakeLanguageServer { error: None, result: Some(result), }; - (serde_json::to_vec(&response).unwrap(), responded_tx) + responded_tx.unbounded_send(()).ok(); + serde_json::to_vec(&response).unwrap() }), ); - if prev_handler.is_some() { - panic!( - "registered a new handler for LSP method '{}' before the previous handler was called", - T::METHOD - ); - } responded_rx } + pub fn remove_request_handler(&mut self) + where + T: 'static + request::Request, + { + self.handlers.lock().remove(T::METHOD); + } + pub async fn start_progress(&mut self, token: impl Into) { self.notify::(ProgressParams { token: NumberOrString::String(token.into()), @@ -777,7 +770,7 @@ mod tests { #[gpui::test] async fn test_fake(cx: TestAppContext) { - let (server, mut fake) = LanguageServer::fake(&cx).await; + let (server, mut fake) = LanguageServer::fake(cx.background()); let (message_tx, message_rx) = channel::unbounded(); let (diagnostics_tx, diagnostics_rx) = channel::unbounded(); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index df339477aa9860adee4721d1b821e59a93807370..208604bd09587d72b83a94c093db8402b841f07f 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -2,7 +2,7 @@ pub mod fs; mod ignore; pub mod worktree; -use anyhow::{anyhow, Result}; +use anyhow::{anyhow, Context, Result}; use client::{proto, Client, PeerId, TypedEnvelope, User, UserStore}; use clock::ReplicaId; use collections::{hash_map, HashMap, HashSet}; @@ -10,17 +10,17 @@ use futures::Future; use fuzzy::{PathMatch, PathMatchCandidate, PathMatchCandidateSet}; use gpui::{ AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, - WeakModelHandle, + UpgradeModelHandle, WeakModelHandle, }; use language::{ point_from_lsp, proto::{deserialize_anchor, serialize_anchor}, range_from_lsp, AnchorRangeExt, Bias, Buffer, CodeAction, Completion, CompletionLabel, - Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, PointUtf16, ToLspPosition, - ToOffset, ToPointUtf16, Transaction, + Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, Operation, PointUtf16, + ToLspPosition, ToOffset, ToPointUtf16, Transaction, }; use lsp::{DiagnosticSeverity, LanguageServer}; -use postage::{prelude::Stream, watch}; +use postage::{broadcast, prelude::Stream, sink::Sink, watch}; use smol::block_on; use std::{ convert::TryInto, @@ -46,7 +46,8 @@ pub struct Project { collaborators: HashMap, subscriptions: Vec, language_servers_with_diagnostics_running: isize, - open_buffers: HashMap>, + open_buffers: HashMap, + opened_buffer: broadcast::Sender<()>, loading_buffers: HashMap< ProjectPath, postage::watch::Receiver, Arc>>>, @@ -54,6 +55,11 @@ pub struct Project { shared_buffers: HashMap>>, } +enum OpenBuffer { + Loaded(WeakModelHandle), + Loading(Vec), +} + enum WorktreeHandle { Strong(ModelHandle), Weak(WeakModelHandle), @@ -155,6 +161,31 @@ pub struct ProjectEntry { } impl Project { + pub fn init(client: &Arc) { + client.add_entity_message_handler(Self::handle_add_collaborator); + client.add_entity_message_handler(Self::handle_buffer_reloaded); + client.add_entity_message_handler(Self::handle_buffer_saved); + client.add_entity_message_handler(Self::handle_close_buffer); + client.add_entity_message_handler(Self::handle_disk_based_diagnostics_updated); + client.add_entity_message_handler(Self::handle_disk_based_diagnostics_updating); + client.add_entity_message_handler(Self::handle_remove_collaborator); + client.add_entity_message_handler(Self::handle_share_worktree); + client.add_entity_message_handler(Self::handle_unregister_worktree); + client.add_entity_message_handler(Self::handle_unshare_project); + client.add_entity_message_handler(Self::handle_update_buffer_file); + client.add_entity_message_handler(Self::handle_update_buffer); + client.add_entity_message_handler(Self::handle_update_diagnostic_summary); + client.add_entity_message_handler(Self::handle_update_worktree); + client.add_entity_request_handler(Self::handle_apply_additional_edits_for_completion); + client.add_entity_request_handler(Self::handle_apply_code_action); + client.add_entity_request_handler(Self::handle_format_buffers); + client.add_entity_request_handler(Self::handle_get_code_actions); + client.add_entity_request_handler(Self::handle_get_completions); + client.add_entity_request_handler(Self::handle_get_definition); + client.add_entity_request_handler(Self::handle_open_buffer); + client.add_entity_request_handler(Self::handle_save_buffer); + } + pub fn local( client: Arc, user_store: ModelHandle, @@ -216,6 +247,7 @@ impl Project { remote_id_rx, _maintain_remote_id_task, }, + opened_buffer: broadcast::channel(1).0, subscriptions: Vec::new(), active_entry: None, languages, @@ -254,70 +286,19 @@ impl Project { load_task.detach(); } - let user_ids = response - .collaborators - .iter() - .map(|peer| peer.user_id) - .collect(); - user_store - .update(cx, |user_store, cx| user_store.load_users(user_ids, cx)) - .await?; - let mut collaborators = HashMap::default(); - for message in response.collaborators { - let collaborator = Collaborator::from_proto(message, &user_store, cx).await?; - collaborators.insert(collaborator.peer_id, collaborator); - } - - Ok(cx.add_model(|cx| { + let this = cx.add_model(|cx| { let mut this = Self { worktrees: Vec::new(), open_buffers: Default::default(), loading_buffers: Default::default(), + opened_buffer: broadcast::channel(1).0, shared_buffers: Default::default(), active_entry: None, - collaborators, + collaborators: Default::default(), languages, - user_store, + user_store: user_store.clone(), fs, - subscriptions: vec![ - client.add_entity_message_handler(remote_id, cx, Self::handle_unshare_project), - client.add_entity_message_handler(remote_id, cx, Self::handle_add_collaborator), - client.add_entity_message_handler( - remote_id, - cx, - Self::handle_remove_collaborator, - ), - client.add_entity_message_handler(remote_id, cx, Self::handle_share_worktree), - client.add_entity_message_handler( - remote_id, - cx, - Self::handle_unregister_worktree, - ), - client.add_entity_message_handler(remote_id, cx, Self::handle_update_worktree), - client.add_entity_message_handler( - remote_id, - cx, - Self::handle_update_diagnostic_summary, - ), - client.add_entity_message_handler( - remote_id, - cx, - Self::handle_disk_based_diagnostics_updating, - ), - client.add_entity_message_handler( - remote_id, - cx, - Self::handle_disk_based_diagnostics_updated, - ), - client.add_entity_message_handler(remote_id, cx, Self::handle_update_buffer), - client.add_entity_message_handler( - remote_id, - cx, - Self::handle_update_buffer_file, - ), - client.add_entity_message_handler(remote_id, cx, Self::handle_buffer_reloaded), - client.add_entity_message_handler(remote_id, cx, Self::handle_buffer_saved), - ], + subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)], client, client_state: ProjectClientState::Remote { sharing_has_stopped: false, @@ -331,7 +312,27 @@ impl Project { this.add_worktree(&worktree, cx); } this - })) + }); + + let user_ids = response + .collaborators + .iter() + .map(|peer| peer.user_id) + .collect(); + user_store + .update(cx, |user_store, cx| user_store.load_users(user_ids, cx)) + .await?; + let mut collaborators = HashMap::default(); + for message in response.collaborators { + let collaborator = Collaborator::from_proto(message, &user_store, cx).await?; + collaborators.insert(collaborator.peer_id, collaborator); + } + + this.update(cx, |this, _| { + this.collaborators = collaborators; + }); + + Ok(this) } #[cfg(any(test, feature = "test-support"))] @@ -343,6 +344,25 @@ impl Project { cx.update(|cx| Project::local(client, user_store, languages, fs, cx)) } + #[cfg(any(test, feature = "test-support"))] + pub fn shared_buffer(&self, peer_id: PeerId, remote_id: u64) -> Option> { + self.shared_buffers + .get(&peer_id) + .and_then(|buffers| buffers.get(&remote_id)) + .cloned() + } + + #[cfg(any(test, feature = "test-support"))] + pub fn has_buffered_operations(&self) -> bool { + self.open_buffers + .values() + .any(|buffer| matches!(buffer, OpenBuffer::Loading(_))) + } + + pub fn fs(&self) -> &Arc { + &self.fs + } + fn set_remote_id(&mut self, remote_id: Option, cx: &mut ModelContext) { if let ProjectClientState::Local { remote_id_tx, .. } = &mut self.client_state { *remote_id_tx.borrow_mut() = remote_id; @@ -350,27 +370,8 @@ impl Project { self.subscriptions.clear(); if let Some(remote_id) = remote_id { - let client = &self.client; - self.subscriptions.extend([ - client.add_entity_request_handler(remote_id, cx, Self::handle_open_buffer), - client.add_entity_message_handler(remote_id, cx, Self::handle_close_buffer), - client.add_entity_message_handler(remote_id, cx, Self::handle_add_collaborator), - client.add_entity_message_handler(remote_id, cx, Self::handle_remove_collaborator), - client.add_entity_message_handler(remote_id, cx, Self::handle_update_worktree), - client.add_entity_message_handler(remote_id, cx, Self::handle_update_buffer), - client.add_entity_request_handler(remote_id, cx, Self::handle_save_buffer), - client.add_entity_message_handler(remote_id, cx, Self::handle_buffer_saved), - client.add_entity_request_handler(remote_id, cx, Self::handle_format_buffers), - client.add_entity_request_handler(remote_id, cx, Self::handle_get_completions), - client.add_entity_request_handler( - remote_id, - cx, - Self::handle_apply_additional_edits_for_completion, - ), - client.add_entity_request_handler(remote_id, cx, Self::handle_get_code_actions), - client.add_entity_request_handler(remote_id, cx, Self::handle_apply_code_action), - client.add_entity_request_handler(remote_id, cx, Self::handle_get_definition), - ]); + self.subscriptions + .push(self.client.add_model_for_remote_entity(remote_id, cx)); } } @@ -521,6 +522,10 @@ impl Project { } } + pub fn is_remote(&self) -> bool { + !self.is_local() + } + pub fn open_buffer( &mut self, path: impl Into, @@ -560,6 +565,11 @@ impl Project { *tx.borrow_mut() = Some(this.update(&mut cx, |this, _| { // Record the fact that the buffer is no longer loading. this.loading_buffers.remove(&project_path); + if this.loading_buffers.is_empty() { + this.open_buffers + .retain(|_, buffer| matches!(buffer, OpenBuffer::Loaded(_))) + } + let buffer = load_result.map_err(Arc::new)?; Ok(buffer) })); @@ -626,6 +636,7 @@ impl Project { .await?; let buffer = response.buffer.ok_or_else(|| anyhow!("missing buffer"))?; this.update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx)) + .await }) } @@ -737,12 +748,15 @@ impl Project { worktree: Option<&ModelHandle>, cx: &mut ModelContext, ) -> Result<()> { - if self - .open_buffers - .insert(buffer.read(cx).remote_id() as usize, buffer.downgrade()) - .is_some() - { - return Err(anyhow!("registered the same buffer twice")); + match self.open_buffers.insert( + buffer.read(cx).remote_id(), + OpenBuffer::Loaded(buffer.downgrade()), + ) { + None => {} + Some(OpenBuffer::Loading(operations)) => { + buffer.update(cx, |buffer, cx| buffer.apply_ops(operations, cx))? + } + Some(OpenBuffer::Loaded(_)) => Err(anyhow!("registered the same buffer twice"))?, } self.assign_language_to_buffer(&buffer, worktree, cx); Ok(()) @@ -1263,29 +1277,27 @@ impl Project { }; cx.spawn(|this, mut cx| async move { let response = client.request(request).await?; - this.update(&mut cx, |this, cx| { - let mut definitions = Vec::new(); - for definition in response.definitions { - let target_buffer = this.deserialize_buffer( - definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?, - cx, - )?; - let target_start = definition - .target_start - .and_then(deserialize_anchor) - .ok_or_else(|| anyhow!("missing target start"))?; - let target_end = definition - .target_end - .and_then(deserialize_anchor) - .ok_or_else(|| anyhow!("missing target end"))?; - definitions.push(Definition { - target_buffer, - target_range: target_start..target_end, - }) - } + let mut definitions = Vec::new(); + for definition in response.definitions { + let buffer = definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?; + let target_buffer = this + .update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx)) + .await?; + let target_start = definition + .target_start + .and_then(deserialize_anchor) + .ok_or_else(|| anyhow!("missing target start"))?; + let target_end = definition + .target_end + .and_then(deserialize_anchor) + .ok_or_else(|| anyhow!("missing target end"))?; + definitions.push(Definition { + target_buffer, + target_range: target_start..target_end, + }) + } - Ok(definitions) - }) + Ok(definitions) }) } else { Task::ready(Ok(Default::default())) @@ -1324,18 +1336,19 @@ impl Project { cx.spawn(|_, cx| async move { let completions = lang_server - .request::(lsp::CompletionParams { - text_document_position: lsp::TextDocumentPositionParams::new( - lsp::TextDocumentIdentifier::new( - lsp::Url::from_file_path(buffer_abs_path).unwrap(), + .request::(lsp::CompletionParams { + text_document_position: lsp::TextDocumentPositionParams::new( + lsp::TextDocumentIdentifier::new( + lsp::Url::from_file_path(buffer_abs_path).unwrap(), + ), + position.to_lsp_position(), ), - position.to_lsp_position(), - ), - context: Default::default(), - work_done_progress_params: Default::default(), - partial_result_params: Default::default(), - }) - .await?; + context: Default::default(), + work_done_progress_params: Default::default(), + partial_result_params: Default::default(), + }) + .await + .context("lsp completion request failed")?; let completions = if let Some(completions) = completions { match completions { @@ -1347,41 +1360,56 @@ impl Project { }; source_buffer_handle.read_with(&cx, |this, _| { - Ok(completions.into_iter().filter_map(|lsp_completion| { - let (old_range, new_text) = match lsp_completion.text_edit.as_ref()? { - lsp::CompletionTextEdit::Edit(edit) => (range_from_lsp(edit.range), edit.new_text.clone()), - lsp::CompletionTextEdit::InsertAndReplace(_) => { - log::info!("received an insert and replace completion but we don't yet support that"); - return None - }, - }; - - let clipped_start = this.clip_point_utf16(old_range.start, Bias::Left); - let clipped_end = this.clip_point_utf16(old_range.end, Bias::Left) ; - if clipped_start == old_range.start && clipped_end == old_range.end { - Some(Completion { - old_range: this.anchor_before(old_range.start)..this.anchor_after(old_range.end), - new_text, - label: language.as_ref().and_then(|l| l.label_for_completion(&lsp_completion)).unwrap_or_else(|| CompletionLabel::plain(&lsp_completion)), - lsp_completion, - }) - } else { - None - } - }).collect()) + Ok(completions + .into_iter() + .filter_map(|lsp_completion| { + let (old_range, new_text) = match lsp_completion.text_edit.as_ref()? { + lsp::CompletionTextEdit::Edit(edit) => { + (range_from_lsp(edit.range), edit.new_text.clone()) + } + lsp::CompletionTextEdit::InsertAndReplace(_) => { + log::info!("unsupported insert/replace completion"); + return None; + } + }; + + let clipped_start = this.clip_point_utf16(old_range.start, Bias::Left); + let clipped_end = this.clip_point_utf16(old_range.end, Bias::Left); + if clipped_start == old_range.start && clipped_end == old_range.end { + Some(Completion { + old_range: this.anchor_before(old_range.start) + ..this.anchor_after(old_range.end), + new_text, + label: language + .as_ref() + .and_then(|l| l.label_for_completion(&lsp_completion)) + .unwrap_or_else(|| CompletionLabel::plain(&lsp_completion)), + lsp_completion, + }) + } else { + None + } + }) + .collect()) }) - }) } else if let Some(project_id) = self.remote_id() { let rpc = self.client.clone(); - cx.foreground().spawn(async move { - let response = rpc - .request(proto::GetCompletions { - project_id, - buffer_id, - position: Some(language::proto::serialize_anchor(&anchor)), + let message = proto::GetCompletions { + project_id, + buffer_id, + position: Some(language::proto::serialize_anchor(&anchor)), + version: (&source_buffer.version()).into(), + }; + cx.spawn_weak(|_, mut cx| async move { + let response = rpc.request(message).await?; + + source_buffer_handle + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(response.version.into()) }) - .await?; + .await; + response .completions .into_iter() @@ -1550,7 +1578,7 @@ impl Project { }) } else if let Some(project_id) = self.remote_id() { let rpc = self.client.clone(); - cx.foreground().spawn(async move { + cx.spawn_weak(|_, mut cx| async move { let response = rpc .request(proto::GetCodeActions { project_id, @@ -1559,6 +1587,13 @@ impl Project { end: Some(language::proto::serialize_anchor(&range.end)), }) .await?; + + buffer_handle + .update(&mut cx, |buffer, _| { + buffer.wait_for_version(response.version.into()) + }) + .await; + response .actions .into_iter() @@ -2124,9 +2159,9 @@ impl Project { this.update(&mut cx, |this, cx| { let worktree_id = WorktreeId::from_proto(envelope.payload.worktree_id); if let Some(worktree) = this.worktree_for_id(worktree_id, cx) { - worktree.update(cx, |worktree, cx| { + worktree.update(cx, |worktree, _| { let worktree = worktree.as_remote_mut().unwrap(); - worktree.update_from_remote(envelope, cx) + worktree.update_from_remote(envelope) })?; } Ok(()) @@ -2188,15 +2223,26 @@ impl Project { ) -> Result<()> { this.update(&mut cx, |this, cx| { let payload = envelope.payload.clone(); - let buffer_id = payload.buffer_id as usize; + let buffer_id = payload.buffer_id; let ops = payload .operations .into_iter() .map(|op| language::proto::deserialize_operation(op)) .collect::, _>>()?; - if let Some(buffer) = this.open_buffers.get_mut(&buffer_id) { - if let Some(buffer) = buffer.upgrade(cx) { - buffer.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))?; + let is_remote = this.is_remote(); + match this.open_buffers.entry(buffer_id) { + hash_map::Entry::Occupied(mut e) => match e.get_mut() { + OpenBuffer::Loaded(buffer) => { + if let Some(buffer) = buffer.upgrade(cx) { + buffer.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))?; + } + } + OpenBuffer::Loading(operations) => operations.extend_from_slice(&ops), + }, + hash_map::Entry::Vacant(e) => { + if is_remote && this.loading_buffers.len() > 0 { + e.insert(OpenBuffer::Loading(ops)); + } } } Ok(()) @@ -2211,7 +2257,7 @@ impl Project { ) -> Result<()> { this.update(&mut cx, |this, cx| { let payload = envelope.payload.clone(); - let buffer_id = payload.buffer_id as usize; + let buffer_id = payload.buffer_id; let file = payload.file.ok_or_else(|| anyhow!("invalid file"))?; let worktree = this .worktree_for_id(WorktreeId::from_proto(file.worktree_id), cx) @@ -2237,21 +2283,30 @@ impl Project { ) -> Result { let buffer_id = envelope.payload.buffer_id; let sender_id = envelope.original_sender_id()?; - let (project_id, save) = this.update(&mut cx, |this, cx| { + let requested_version = envelope.payload.version.try_into()?; + + let (project_id, buffer) = this.update(&mut cx, |this, _| { let project_id = this.remote_id().ok_or_else(|| anyhow!("not connected"))?; let buffer = this .shared_buffers .get(&sender_id) .and_then(|shared_buffers| shared_buffers.get(&buffer_id).cloned()) .ok_or_else(|| anyhow!("unknown buffer id {}", buffer_id))?; - Ok::<_, anyhow::Error>((project_id, buffer.update(cx, |buffer, cx| buffer.save(cx)))) + Ok::<_, anyhow::Error>((project_id, buffer)) })?; - let (version, mtime) = save.await?; + if !buffer + .read_with(&cx, |buffer, _| buffer.version()) + .observed_all(&requested_version) + { + Err(anyhow!("save request depends on unreceived edits"))?; + } + + let (saved_version, mtime) = buffer.update(&mut cx, |buffer, cx| buffer.save(cx)).await?; Ok(proto::BufferSaved { project_id, buffer_id, - version: (&version).into(), + version: (&saved_version).into(), mtime: Some(mtime.into()), }) } @@ -2301,21 +2356,30 @@ impl Project { .position .and_then(language::proto::deserialize_anchor) .ok_or_else(|| anyhow!("invalid position"))?; - let completions = this.update(&mut cx, |this, cx| { - let buffer = this - .shared_buffers + let version = clock::Global::from(envelope.payload.version); + let buffer = this.read_with(&cx, |this, _| { + this.shared_buffers .get(&sender_id) .and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned()) - .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?; - Ok::<_, anyhow::Error>(this.completions(&buffer, position, cx)) + .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id)) })?; + if !buffer + .read_with(&cx, |buffer, _| buffer.version()) + .observed_all(&version) + { + Err(anyhow!("completion request depends on unreceived edits"))?; + } + let version = buffer.read_with(&cx, |buffer, _| buffer.version()); + let completions = this + .update(&mut cx, |this, cx| this.completions(&buffer, position, cx)) + .await?; Ok(proto::GetCompletionsResponse { completions: completions - .await? .iter() .map(language::proto::serialize_completion) .collect(), + version: (&version).into(), }) } @@ -2370,12 +2434,17 @@ impl Project { .end .and_then(language::proto::deserialize_anchor) .ok_or_else(|| anyhow!("invalid end"))?; - let code_actions = this.update(&mut cx, |this, cx| { - let buffer = this - .shared_buffers + let buffer = this.update(&mut cx, |this, _| { + this.shared_buffers .get(&sender_id) .and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned()) - .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?; + .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id)) + })?; + let version = buffer.read_with(&cx, |buffer, _| buffer.version()); + if !version.observed(start.timestamp) || !version.observed(end.timestamp) { + Err(anyhow!("code action request references unreceived edits"))?; + } + let code_actions = this.update(&mut cx, |this, cx| { Ok::<_, anyhow::Error>(this.code_actions(&buffer, start..end, cx)) })?; @@ -2385,6 +2454,7 @@ impl Project { .iter() .map(language::proto::serialize_code_action) .collect(), + version: (&version).into(), }) } @@ -2516,20 +2586,15 @@ impl Project { push_to_history: bool, cx: &mut ModelContext, ) -> Task> { - let mut project_transaction = ProjectTransaction::default(); - for (buffer, transaction) in message.buffers.into_iter().zip(message.transactions) { - let buffer = match self.deserialize_buffer(buffer, cx) { - Ok(buffer) => buffer, - Err(error) => return Task::ready(Err(error)), - }; - let transaction = match language::proto::deserialize_transaction(transaction) { - Ok(transaction) => transaction, - Err(error) => return Task::ready(Err(error)), - }; - project_transaction.0.insert(buffer, transaction); - } - - cx.spawn_weak(|_, mut cx| async move { + cx.spawn(|this, mut cx| async move { + let mut project_transaction = ProjectTransaction::default(); + for (buffer, transaction) in message.buffers.into_iter().zip(message.transactions) { + let buffer = this + .update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx)) + .await?; + let transaction = language::proto::deserialize_transaction(transaction)?; + project_transaction.0.insert(buffer, transaction); + } for (buffer, transaction) in &project_transaction.0 { buffer .update(&mut cx, |buffer, _| { @@ -2573,33 +2638,60 @@ impl Project { &mut self, buffer: proto::Buffer, cx: &mut ModelContext, - ) -> Result> { - match buffer.variant.ok_or_else(|| anyhow!("missing buffer"))? { - proto::buffer::Variant::Id(id) => self - .open_buffers - .get(&(id as usize)) - .and_then(|buffer| buffer.upgrade(cx)) - .ok_or_else(|| anyhow!("no buffer exists for id {}", id)), - proto::buffer::Variant::State(mut buffer) => { - let mut buffer_worktree = None; - let mut buffer_file = None; - if let Some(file) = buffer.file.take() { - let worktree_id = WorktreeId::from_proto(file.worktree_id); - let worktree = self - .worktree_for_id(worktree_id, cx) - .ok_or_else(|| anyhow!("no worktree found for id {}", file.worktree_id))?; - buffer_file = Some(Box::new(File::from_proto(file, worktree.clone(), cx)?) - as Box); - buffer_worktree = Some(worktree); + ) -> Task>> { + let replica_id = self.replica_id(); + + let mut opened_buffer_tx = self.opened_buffer.clone(); + let mut opened_buffer_rx = self.opened_buffer.subscribe(); + cx.spawn(|this, mut cx| async move { + match buffer.variant.ok_or_else(|| anyhow!("missing buffer"))? { + proto::buffer::Variant::Id(id) => { + let buffer = loop { + let buffer = this.read_with(&cx, |this, cx| { + this.open_buffers + .get(&id) + .and_then(|buffer| buffer.upgrade(cx)) + }); + if let Some(buffer) = buffer { + break buffer; + } + opened_buffer_rx + .recv() + .await + .ok_or_else(|| anyhow!("project dropped while waiting for buffer"))?; + }; + Ok(buffer) } + proto::buffer::Variant::State(mut buffer) => { + let mut buffer_worktree = None; + let mut buffer_file = None; + if let Some(file) = buffer.file.take() { + this.read_with(&cx, |this, cx| { + let worktree_id = WorktreeId::from_proto(file.worktree_id); + let worktree = + this.worktree_for_id(worktree_id, cx).ok_or_else(|| { + anyhow!("no worktree found for id {}", file.worktree_id) + })?; + buffer_file = + Some(Box::new(File::from_proto(file, worktree.clone(), cx)?) + as Box); + buffer_worktree = Some(worktree); + Ok::<_, anyhow::Error>(()) + })?; + } - let buffer = cx.add_model(|cx| { - Buffer::from_proto(self.replica_id(), buffer, buffer_file, cx).unwrap() - }); - self.register_buffer(&buffer, buffer_worktree.as_ref(), cx)?; - Ok(buffer) + let buffer = cx.add_model(|cx| { + Buffer::from_proto(replica_id, buffer, buffer_file, cx).unwrap() + }); + this.update(&mut cx, |this, cx| { + this.register_buffer(&buffer, buffer_worktree.as_ref(), cx) + })?; + + let _ = opened_buffer_tx.send(()).await; + Ok(buffer) + } } - } + }) } async fn handle_close_buffer( @@ -2635,7 +2727,7 @@ impl Project { this.update(&mut cx, |this, cx| { let buffer = this .open_buffers - .get(&(envelope.payload.buffer_id as usize)) + .get(&envelope.payload.buffer_id) .and_then(|buffer| buffer.upgrade(cx)); if let Some(buffer) = buffer { buffer.update(cx, |buffer, cx| { @@ -2661,7 +2753,7 @@ impl Project { this.update(&mut cx, |this, cx| { let buffer = this .open_buffers - .get(&(payload.buffer_id as usize)) + .get(&payload.buffer_id) .and_then(|buffer| buffer.upgrade(cx)); if let Some(buffer) = buffer { buffer.update(cx, |buffer, cx| { @@ -2719,6 +2811,15 @@ impl WorktreeHandle { } } +impl OpenBuffer { + pub fn upgrade(&self, cx: &impl UpgradeModelHandle) -> Option> { + match self { + OpenBuffer::Loaded(handle) => handle.upgrade(cx), + OpenBuffer::Loading(_) => None, + } + } +} + struct CandidateSet { snapshot: Snapshot, include_ignored: bool, @@ -2959,7 +3060,7 @@ mod tests { #[gpui::test] async fn test_language_server_diagnostics(mut cx: gpui::TestAppContext) { - let (language_server_config, mut fake_server) = LanguageServerConfig::fake(&cx).await; + let (language_server_config, mut fake_servers) = LanguageServerConfig::fake(); let progress_token = language_server_config .disk_based_diagnostics_progress_token .clone() @@ -3022,6 +3123,7 @@ mod tests { let mut events = subscribe(&project, &mut cx); + let mut fake_server = fake_servers.next().await.unwrap(); fake_server.start_progress(&progress_token).await; assert_eq!( events.next().await.unwrap(), @@ -3123,7 +3225,7 @@ mod tests { #[gpui::test] async fn test_definition(mut cx: gpui::TestAppContext) { - let (language_server_config, mut fake_server) = LanguageServerConfig::fake(&cx).await; + let (language_server_config, mut fake_servers) = LanguageServerConfig::fake(); let mut languages = LanguageRegistry::new(); languages.add(Arc::new(Language::new( @@ -3178,6 +3280,7 @@ mod tests { .await .unwrap(); + let mut fake_server = fake_servers.next().await.unwrap(); fake_server.handle_request::(move |params| { let params = params.text_document_position_params; assert_eq!( @@ -3371,7 +3474,7 @@ mod tests { .await; // Create a remote copy of this worktree. - let initial_snapshot = tree.read_with(&cx, |tree, _| tree.snapshot()); + let initial_snapshot = tree.read_with(&cx, |tree, _| tree.as_local().unwrap().snapshot()); let (remote, load_task) = cx.update(|cx| { Worktree::remote( 1, @@ -3447,10 +3550,13 @@ mod tests { // Update the remote worktree. Check that it becomes consistent with the // local worktree. remote.update(&mut cx, |remote, cx| { - let update_message = - tree.read(cx) - .snapshot() - .build_update(&initial_snapshot, 1, 1, true); + let update_message = tree.read(cx).as_local().unwrap().snapshot().build_update( + &initial_snapshot, + 1, + 1, + 0, + true, + ); remote .as_remote_mut() .unwrap() diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 79e3a7e528b0b310978d46cb26f64f5bf2be546d..074781449db41204a4957da0385ccb491ba2383b 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -7,8 +7,11 @@ use ::ignore::gitignore::{Gitignore, GitignoreBuilder}; use anyhow::{anyhow, Result}; use client::{proto, Client, TypedEnvelope}; use clock::ReplicaId; -use collections::HashMap; -use futures::{Stream, StreamExt}; +use collections::{HashMap, VecDeque}; +use futures::{ + channel::mpsc::{self, UnboundedSender}, + Stream, StreamExt, +}; use fuzzy::CharBag; use gpui::{ executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, @@ -18,6 +21,7 @@ use language::{Buffer, DiagnosticEntry, Operation, PointUtf16, Rope}; use lazy_static::lazy_static; use parking_lot::Mutex; use postage::{ + oneshot, prelude::{Sink as _, Stream as _}, watch, }; @@ -75,11 +79,13 @@ pub struct RemoteWorktree { project_id: u64, snapshot_rx: watch::Receiver, client: Arc, - updates_tx: postage::mpsc::Sender, + updates_tx: UnboundedSender, replica_id: ReplicaId, queued_operations: Vec<(u64, Operation)>, diagnostic_summaries: TreeMap, weak: bool, + next_update_id: u64, + pending_updates: VecDeque, } #[derive(Clone)] @@ -208,7 +214,7 @@ impl Worktree { entries_by_id: Default::default(), }; - let (updates_tx, mut updates_rx) = postage::mpsc::channel(64); + let (updates_tx, mut updates_rx) = mpsc::unbounded(); let (mut snapshot_tx, snapshot_rx) = watch::channel_with(snapshot.clone()); let worktree_handle = cx.add_model(|_: &mut ModelContext| { Worktree::Remote(RemoteWorktree { @@ -233,6 +239,8 @@ impl Worktree { }), ), weak, + next_update_id: worktree.next_update_id, + pending_updates: Default::default(), }) }); @@ -276,7 +284,7 @@ impl Worktree { cx.background() .spawn(async move { - while let Some(update) = updates_rx.recv().await { + while let Some(update) = updates_rx.next().await { let mut snapshot = snapshot_tx.borrow().clone(); if let Err(error) = snapshot.apply_remote_update(update) { log::error!("error applying worktree update: {}", error); @@ -450,7 +458,7 @@ impl LocalWorktree { weak: bool, fs: Arc, cx: &mut AsyncAppContext, - ) -> Result<(ModelHandle, Sender)> { + ) -> Result<(ModelHandle, UnboundedSender)> { let abs_path = path.into(); let path: Arc = Arc::from(Path::new("")); let next_entry_id = AtomicUsize::new(0); @@ -470,7 +478,7 @@ impl LocalWorktree { } } - let (scan_states_tx, scan_states_rx) = smol::channel::unbounded(); + let (scan_states_tx, mut scan_states_rx) = mpsc::unbounded(); let (mut last_scan_state_tx, last_scan_state_rx) = watch::channel_with(ScanState::Scanning); let tree = cx.add_model(move |cx: &mut ModelContext| { let mut snapshot = LocalSnapshot { @@ -515,7 +523,7 @@ impl LocalWorktree { }; cx.spawn_weak(|this, mut cx| async move { - while let Ok(scan_state) = scan_states_rx.recv().await { + while let Some(scan_state) = scan_states_rx.next().await { if let Some(handle) = this.upgrade(&cx) { let to_send = handle.update(&mut cx, |this, cx| { last_scan_state_tx.blocking_send(scan_state).ok(); @@ -761,16 +769,41 @@ impl LocalWorktree { let worktree_id = cx.model_id() as u64; let (snapshots_to_send_tx, snapshots_to_send_rx) = smol::channel::unbounded::(); + let (mut share_tx, mut share_rx) = oneshot::channel(); let maintain_remote_snapshot = cx.background().spawn({ let rpc = rpc.clone(); let snapshot = snapshot.clone(); + let diagnostic_summaries = self.diagnostic_summaries.clone(); + let weak = self.weak; async move { + if let Err(error) = rpc + .request(proto::ShareWorktree { + project_id, + worktree: Some(snapshot.to_proto(&diagnostic_summaries, weak)), + }) + .await + { + let _ = share_tx.try_send(Err(error)); + return; + } else { + let _ = share_tx.try_send(Ok(())); + } + + let mut update_id = 0; let mut prev_snapshot = snapshot; while let Ok(snapshot) = snapshots_to_send_rx.recv().await { - let message = - snapshot.build_update(&prev_snapshot, project_id, worktree_id, false); - match rpc.send(message) { - Ok(()) => prev_snapshot = snapshot, + let message = snapshot.build_update( + &prev_snapshot, + project_id, + worktree_id, + update_id, + false, + ); + match rpc.request(message).await { + Ok(_) => { + prev_snapshot = snapshot; + update_id += 1; + } Err(err) => log::error!("error sending snapshot diff {}", err), } } @@ -782,18 +815,11 @@ impl LocalWorktree { _maintain_remote_snapshot: Some(maintain_remote_snapshot), }); - let diagnostic_summaries = self.diagnostic_summaries.clone(); - let weak = self.weak; - let share_message = cx.background().spawn(async move { - proto::ShareWorktree { - project_id, - worktree: Some(snapshot.to_proto(&diagnostic_summaries, weak)), - } - }); - cx.foreground().spawn(async move { - rpc.request(share_message.await).await?; - Ok(()) + match share_rx.next().await { + Some(result) => result, + None => Err(anyhow!("unshared before sharing completed")), + } }) } @@ -814,19 +840,39 @@ impl RemoteWorktree { pub fn update_from_remote( &mut self, envelope: TypedEnvelope, - cx: &mut ModelContext, ) -> Result<()> { - let mut tx = self.updates_tx.clone(); - let payload = envelope.payload.clone(); - cx.foreground() - .spawn(async move { - tx.send(payload).await.expect("receiver runs to completion"); - }) - .detach(); + let update = envelope.payload; + if update.id > self.next_update_id { + let ix = match self + .pending_updates + .binary_search_by_key(&update.id, |pending| pending.id) + { + Ok(ix) | Err(ix) => ix, + }; + self.pending_updates.insert(ix, update); + } else { + let tx = self.updates_tx.clone(); + self.next_update_id += 1; + tx.unbounded_send(update) + .expect("consumer runs to completion"); + while let Some(update) = self.pending_updates.front() { + if update.id == self.next_update_id { + self.next_update_id += 1; + tx.unbounded_send(self.pending_updates.pop_front().unwrap()) + .expect("consumer runs to completion"); + } else { + break; + } + } + } Ok(()) } + pub fn has_pending_updates(&self) -> bool { + !self.pending_updates.is_empty() + } + pub fn update_diagnostic_summary( &mut self, path: Arc, @@ -849,94 +895,6 @@ impl Snapshot { self.id } - pub(crate) fn to_proto( - &self, - diagnostic_summaries: &TreeMap, - weak: bool, - ) -> proto::Worktree { - let root_name = self.root_name.clone(); - proto::Worktree { - id: self.id.0 as u64, - root_name, - entries: self - .entries_by_path - .iter() - .filter(|e| !e.is_ignored) - .map(Into::into) - .collect(), - diagnostic_summaries: diagnostic_summaries - .iter() - .map(|(path, summary)| summary.to_proto(path.0.clone())) - .collect(), - weak, - } - } - - pub(crate) fn build_update( - &self, - other: &Self, - project_id: u64, - worktree_id: u64, - include_ignored: bool, - ) -> proto::UpdateWorktree { - let mut updated_entries = Vec::new(); - let mut removed_entries = Vec::new(); - let mut self_entries = self - .entries_by_id - .cursor::<()>() - .filter(|e| include_ignored || !e.is_ignored) - .peekable(); - let mut other_entries = other - .entries_by_id - .cursor::<()>() - .filter(|e| include_ignored || !e.is_ignored) - .peekable(); - loop { - match (self_entries.peek(), other_entries.peek()) { - (Some(self_entry), Some(other_entry)) => { - match Ord::cmp(&self_entry.id, &other_entry.id) { - Ordering::Less => { - let entry = self.entry_for_id(self_entry.id).unwrap().into(); - updated_entries.push(entry); - self_entries.next(); - } - Ordering::Equal => { - if self_entry.scan_id != other_entry.scan_id { - let entry = self.entry_for_id(self_entry.id).unwrap().into(); - updated_entries.push(entry); - } - - self_entries.next(); - other_entries.next(); - } - Ordering::Greater => { - removed_entries.push(other_entry.id as u64); - other_entries.next(); - } - } - } - (Some(self_entry), None) => { - let entry = self.entry_for_id(self_entry.id).unwrap().into(); - updated_entries.push(entry); - self_entries.next(); - } - (None, Some(other_entry)) => { - removed_entries.push(other_entry.id as u64); - other_entries.next(); - } - (None, None) => break, - } - } - - proto::UpdateWorktree { - project_id, - worktree_id, - root_name: self.root_name().to_string(), - updated_entries, - removed_entries, - } - } - pub(crate) fn apply_remote_update(&mut self, update: proto::UpdateWorktree) -> Result<()> { let mut entries_by_path_edits = Vec::new(); let mut entries_by_id_edits = Vec::new(); @@ -1077,6 +1035,97 @@ impl Snapshot { } impl LocalSnapshot { + pub(crate) fn to_proto( + &self, + diagnostic_summaries: &TreeMap, + weak: bool, + ) -> proto::Worktree { + let root_name = self.root_name.clone(); + proto::Worktree { + id: self.id.0 as u64, + root_name, + entries: self + .entries_by_path + .iter() + .filter(|e| !e.is_ignored) + .map(Into::into) + .collect(), + diagnostic_summaries: diagnostic_summaries + .iter() + .map(|(path, summary)| summary.to_proto(path.0.clone())) + .collect(), + weak, + next_update_id: 0, + } + } + + pub(crate) fn build_update( + &self, + other: &Self, + project_id: u64, + worktree_id: u64, + update_id: u64, + include_ignored: bool, + ) -> proto::UpdateWorktree { + let mut updated_entries = Vec::new(); + let mut removed_entries = Vec::new(); + let mut self_entries = self + .entries_by_id + .cursor::<()>() + .filter(|e| include_ignored || !e.is_ignored) + .peekable(); + let mut other_entries = other + .entries_by_id + .cursor::<()>() + .filter(|e| include_ignored || !e.is_ignored) + .peekable(); + loop { + match (self_entries.peek(), other_entries.peek()) { + (Some(self_entry), Some(other_entry)) => { + match Ord::cmp(&self_entry.id, &other_entry.id) { + Ordering::Less => { + let entry = self.entry_for_id(self_entry.id).unwrap().into(); + updated_entries.push(entry); + self_entries.next(); + } + Ordering::Equal => { + if self_entry.scan_id != other_entry.scan_id { + let entry = self.entry_for_id(self_entry.id).unwrap().into(); + updated_entries.push(entry); + } + + self_entries.next(); + other_entries.next(); + } + Ordering::Greater => { + removed_entries.push(other_entry.id as u64); + other_entries.next(); + } + } + } + (Some(self_entry), None) => { + let entry = self.entry_for_id(self_entry.id).unwrap().into(); + updated_entries.push(entry); + self_entries.next(); + } + (None, Some(other_entry)) => { + removed_entries.push(other_entry.id as u64); + other_entries.next(); + } + (None, None) => break, + } + } + + proto::UpdateWorktree { + id: update_id as u64, + project_id, + worktree_id, + root_name: self.root_name().to_string(), + updated_entries, + removed_entries, + } + } + fn insert_entry(&mut self, mut entry: Entry, fs: &dyn Fs) -> Entry { if !entry.is_dir() && entry.path.file_name() == Some(&GITIGNORE) { let abs_path = self.abs_path.join(&entry.path); @@ -1283,13 +1332,29 @@ impl fmt::Debug for LocalWorktree { impl fmt::Debug for Snapshot { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - for entry in self.entries_by_path.cursor::<()>() { - for _ in entry.path.ancestors().skip(1) { - write!(f, " ")?; + struct EntriesById<'a>(&'a SumTree); + struct EntriesByPath<'a>(&'a SumTree); + + impl<'a> fmt::Debug for EntriesByPath<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_map() + .entries(self.0.iter().map(|entry| (&entry.path, entry.id))) + .finish() } - writeln!(f, "{:?} (inode: {})", entry.path, entry.inode)?; } - Ok(()) + + impl<'a> fmt::Debug for EntriesById<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_list().entries(self.0.iter()).finish() + } + } + + f.debug_struct("Snapshot") + .field("id", &self.id) + .field("root_name", &self.root_name) + .field("entries_by_path", &EntriesByPath(&self.entries_by_path)) + .field("entries_by_id", &EntriesById(&self.entries_by_id)) + .finish() } } @@ -1322,7 +1387,9 @@ impl language::File for File { fn full_path(&self, cx: &AppContext) -> PathBuf { let mut full_path = PathBuf::new(); full_path.push(self.worktree.read(cx).root_name()); - full_path.push(&self.path); + if self.path.components().next().is_some() { + full_path.push(&self.path); + } full_path } @@ -1372,6 +1439,7 @@ impl language::File for File { .request(proto::SaveBuffer { project_id, buffer_id, + version: (&version).into(), }) .await?; let version = response.version.try_into()?; @@ -1493,7 +1561,7 @@ impl File { } } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub struct Entry { pub id: usize, pub kind: EntryKind, @@ -1504,7 +1572,7 @@ pub struct Entry { pub is_ignored: bool, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, PartialEq, Eq)] pub enum EntryKind { PendingDir, Dir, @@ -1668,14 +1736,14 @@ impl<'a> sum_tree::Dimension<'a, EntrySummary> for PathKey { struct BackgroundScanner { fs: Arc, snapshot: Arc>, - notify: Sender, + notify: UnboundedSender, executor: Arc, } impl BackgroundScanner { fn new( snapshot: Arc>, - notify: Sender, + notify: UnboundedSender, fs: Arc, executor: Arc, ) -> Self { @@ -1696,28 +1764,27 @@ impl BackgroundScanner { } async fn run(mut self, events_rx: impl Stream>) { - if self.notify.send(ScanState::Scanning).await.is_err() { + if self.notify.unbounded_send(ScanState::Scanning).is_err() { return; } if let Err(err) = self.scan_dirs().await { if self .notify - .send(ScanState::Err(Arc::new(err))) - .await + .unbounded_send(ScanState::Err(Arc::new(err))) .is_err() { return; } } - if self.notify.send(ScanState::Idle).await.is_err() { + if self.notify.unbounded_send(ScanState::Idle).is_err() { return; } futures::pin_mut!(events_rx); while let Some(events) = events_rx.next().await { - if self.notify.send(ScanState::Scanning).await.is_err() { + if self.notify.unbounded_send(ScanState::Scanning).is_err() { break; } @@ -1725,7 +1792,7 @@ impl BackgroundScanner { break; } - if self.notify.send(ScanState::Idle).await.is_err() { + if self.notify.unbounded_send(ScanState::Idle).is_err() { break; } } @@ -2391,7 +2458,7 @@ mod tests { fmt::Write, time::{SystemTime, UNIX_EPOCH}, }; - use util::test::temp_tree; + use util::{post_inc, test::temp_tree}; #[gpui::test] async fn test_traversal(cx: gpui::TestAppContext) { @@ -2503,7 +2570,7 @@ mod tests { } log::info!("Generated initial tree"); - let (notify_tx, _notify_rx) = smol::channel::unbounded(); + let (notify_tx, _notify_rx) = mpsc::unbounded(); let fs = Arc::new(RealFs); let next_entry_id = Arc::new(AtomicUsize::new(0)); let mut initial_snapshot = LocalSnapshot { @@ -2563,7 +2630,7 @@ mod tests { smol::block_on(scanner.process_events(events)); scanner.snapshot().check_invariants(); - let (notify_tx, _notify_rx) = smol::channel::unbounded(); + let (notify_tx, _notify_rx) = mpsc::unbounded(); let mut new_scanner = BackgroundScanner::new( Arc::new(Mutex::new(initial_snapshot)), notify_tx, @@ -2576,6 +2643,7 @@ mod tests { new_scanner.snapshot().to_vec(true) ); + let mut update_id = 0; for mut prev_snapshot in snapshots { let include_ignored = rng.gen::(); if !include_ignored { @@ -2596,9 +2664,13 @@ mod tests { prev_snapshot.entries_by_id.edit(entries_by_id_edits, &()); } - let update = scanner - .snapshot() - .build_update(&prev_snapshot, 0, 0, include_ignored); + let update = scanner.snapshot().build_update( + &prev_snapshot, + 0, + 0, + post_inc(&mut update_id), + include_ignored, + ); prev_snapshot.apply_remote_update(update).unwrap(); assert_eq!( prev_snapshot.to_vec(true), diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index cc28e507e97ab517f632bd4b303c082bf0722708..9d7baa8992443ba20b394ba813f9997164f9cf3d 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -9,62 +9,63 @@ message Envelope { Ack ack = 4; Error error = 5; Ping ping = 6; - - RegisterProject register_project = 7; - RegisterProjectResponse register_project_response = 8; - UnregisterProject unregister_project = 9; - ShareProject share_project = 10; - UnshareProject unshare_project = 11; - JoinProject join_project = 12; - JoinProjectResponse join_project_response = 13; - LeaveProject leave_project = 14; - AddProjectCollaborator add_project_collaborator = 15; - RemoveProjectCollaborator remove_project_collaborator = 16; - GetDefinition get_definition = 17; - GetDefinitionResponse get_definition_response = 18; - - RegisterWorktree register_worktree = 19; - UnregisterWorktree unregister_worktree = 20; - ShareWorktree share_worktree = 21; - UpdateWorktree update_worktree = 22; - UpdateDiagnosticSummary update_diagnostic_summary = 23; - DiskBasedDiagnosticsUpdating disk_based_diagnostics_updating = 24; - DiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 25; - - OpenBuffer open_buffer = 26; - OpenBufferResponse open_buffer_response = 27; - CloseBuffer close_buffer = 28; - UpdateBuffer update_buffer = 29; - UpdateBufferFile update_buffer_file = 30; - SaveBuffer save_buffer = 31; - BufferSaved buffer_saved = 32; - BufferReloaded buffer_reloaded = 33; - FormatBuffers format_buffers = 34; - FormatBuffersResponse format_buffers_response = 35; - GetCompletions get_completions = 36; - GetCompletionsResponse get_completions_response = 37; - ApplyCompletionAdditionalEdits apply_completion_additional_edits = 38; - ApplyCompletionAdditionalEditsResponse apply_completion_additional_edits_response = 39; - GetCodeActions get_code_actions = 40; - GetCodeActionsResponse get_code_actions_response = 41; - ApplyCodeAction apply_code_action = 42; - ApplyCodeActionResponse apply_code_action_response = 43; - - GetChannels get_channels = 44; - GetChannelsResponse get_channels_response = 45; - JoinChannel join_channel = 46; - JoinChannelResponse join_channel_response = 47; - LeaveChannel leave_channel = 48; - SendChannelMessage send_channel_message = 49; - SendChannelMessageResponse send_channel_message_response = 50; - ChannelMessageSent channel_message_sent = 51; - GetChannelMessages get_channel_messages = 52; - GetChannelMessagesResponse get_channel_messages_response = 53; - - UpdateContacts update_contacts = 54; - - GetUsers get_users = 55; - GetUsersResponse get_users_response = 56; + Test test = 7; + + RegisterProject register_project = 8; + RegisterProjectResponse register_project_response = 9; + UnregisterProject unregister_project = 10; + ShareProject share_project = 11; + UnshareProject unshare_project = 12; + JoinProject join_project = 13; + JoinProjectResponse join_project_response = 14; + LeaveProject leave_project = 15; + AddProjectCollaborator add_project_collaborator = 16; + RemoveProjectCollaborator remove_project_collaborator = 17; + GetDefinition get_definition = 18; + GetDefinitionResponse get_definition_response = 19; + + RegisterWorktree register_worktree = 20; + UnregisterWorktree unregister_worktree = 21; + ShareWorktree share_worktree = 22; + UpdateWorktree update_worktree = 23; + UpdateDiagnosticSummary update_diagnostic_summary = 24; + DiskBasedDiagnosticsUpdating disk_based_diagnostics_updating = 25; + DiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 26; + + OpenBuffer open_buffer = 27; + OpenBufferResponse open_buffer_response = 28; + CloseBuffer close_buffer = 29; + UpdateBuffer update_buffer = 30; + UpdateBufferFile update_buffer_file = 31; + SaveBuffer save_buffer = 32; + BufferSaved buffer_saved = 33; + BufferReloaded buffer_reloaded = 34; + FormatBuffers format_buffers = 35; + FormatBuffersResponse format_buffers_response = 36; + GetCompletions get_completions = 37; + GetCompletionsResponse get_completions_response = 38; + ApplyCompletionAdditionalEdits apply_completion_additional_edits = 39; + ApplyCompletionAdditionalEditsResponse apply_completion_additional_edits_response = 40; + GetCodeActions get_code_actions = 41; + GetCodeActionsResponse get_code_actions_response = 42; + ApplyCodeAction apply_code_action = 43; + ApplyCodeActionResponse apply_code_action_response = 44; + + GetChannels get_channels = 45; + GetChannelsResponse get_channels_response = 46; + JoinChannel join_channel = 47; + JoinChannelResponse join_channel_response = 48; + LeaveChannel leave_channel = 49; + SendChannelMessage send_channel_message = 50; + SendChannelMessageResponse send_channel_message_response = 51; + ChannelMessageSent channel_message_sent = 52; + GetChannelMessages get_channel_messages = 53; + GetChannelMessagesResponse get_channel_messages_response = 54; + + UpdateContacts update_contacts = 55; + + GetUsers get_users = 56; + GetUsersResponse get_users_response = 57; } } @@ -78,6 +79,10 @@ message Error { string message = 1; } +message Test { + uint64 id = 1; +} + message RegisterProject {} message RegisterProjectResponse { @@ -128,11 +133,12 @@ message ShareWorktree { } message UpdateWorktree { - uint64 project_id = 1; - uint64 worktree_id = 2; - string root_name = 3; - repeated Entry updated_entries = 4; - repeated uint64 removed_entries = 5; + uint64 id = 1; + uint64 project_id = 2; + uint64 worktree_id = 3; + string root_name = 4; + repeated Entry updated_entries = 5; + repeated uint64 removed_entries = 6; } message AddProjectCollaborator { @@ -191,6 +197,7 @@ message UpdateBufferFile { message SaveBuffer { uint64 project_id = 1; uint64 buffer_id = 2; + repeated VectorClockEntry version = 3; } message BufferSaved { @@ -220,10 +227,12 @@ message GetCompletions { uint64 project_id = 1; uint64 buffer_id = 2; Anchor position = 3; + repeated VectorClockEntry version = 4; } message GetCompletionsResponse { repeated Completion completions = 1; + repeated VectorClockEntry version = 2; } message ApplyCompletionAdditionalEdits { @@ -252,6 +261,7 @@ message GetCodeActions { message GetCodeActionsResponse { repeated CodeAction actions = 1; + repeated VectorClockEntry version = 2; } message ApplyCodeAction { @@ -386,6 +396,7 @@ message Worktree { repeated Entry entries = 3; repeated DiagnosticSummary diagnostic_summaries = 4; bool weak = 5; + uint64 next_update_id = 6; } message File { diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 0a614e0bed4cab518418c42b4e12b28b3f69b08c..d37aec47678d6ef1858982c180d39940104e1da6 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -265,7 +265,7 @@ impl Peer { .await .ok_or_else(|| anyhow!("connection was closed"))?; if let Some(proto::envelope::Payload::Error(error)) = &response.payload { - Err(anyhow!("request failed").context(error.message.clone())) + Err(anyhow!("RPC request failed - {}", error.message)) } else { T::Response::from_envelope(response) .ok_or_else(|| anyhow!("received response of the wrong type")) @@ -402,40 +402,18 @@ mod tests { assert_eq!( client1 - .request( - client1_conn_id, - proto::OpenBuffer { - project_id: 0, - worktree_id: 1, - path: "path/one".to_string(), - }, - ) + .request(client1_conn_id, proto::Test { id: 1 },) .await .unwrap(), - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - variant: Some(proto::buffer::Variant::Id(0)) - }), - } + proto::Test { id: 1 } ); assert_eq!( client2 - .request( - client2_conn_id, - proto::OpenBuffer { - project_id: 0, - worktree_id: 2, - path: "path/two".to_string(), - }, - ) + .request(client2_conn_id, proto::Test { id: 2 }) .await .unwrap(), - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - variant: Some(proto::buffer::Variant::Id(1)) - }) - } + proto::Test { id: 2 } ); client1.disconnect(client1_conn_id); @@ -450,34 +428,9 @@ mod tests { if let Some(envelope) = envelope.downcast_ref::>() { let receipt = envelope.receipt(); peer.respond(receipt, proto::Ack {})? - } else if let Some(envelope) = - envelope.downcast_ref::>() + } else if let Some(envelope) = envelope.downcast_ref::>() { - let message = &envelope.payload; - let receipt = envelope.receipt(); - let response = match message.path.as_str() { - "path/one" => { - assert_eq!(message.worktree_id, 1); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - variant: Some(proto::buffer::Variant::Id(0)), - }), - } - } - "path/two" => { - assert_eq!(message.worktree_id, 2); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - variant: Some(proto::buffer::Variant::Id(1)), - }), - } - } - _ => { - panic!("unexpected path {}", message.path); - } - }; - - peer.respond(receipt, response)? + peer.respond(envelope.receipt(), envelope.payload.clone())? } else { panic!("unknown message type"); } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 9aa9eb61b3e3a1eec866f0f43c224c443b98c360..8093f2551fc4fbb380f64f20dff219c6b2ca5927 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -13,6 +13,7 @@ include!(concat!(env!("OUT_DIR"), "/zed.messages.rs")); pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static { const NAME: &'static str; + const PRIORITY: MessagePriority; fn into_envelope( self, id: u32, @@ -35,6 +36,12 @@ pub trait AnyTypedEnvelope: 'static + Send + Sync { fn payload_type_name(&self) -> &'static str; fn as_any(&self) -> &dyn Any; fn into_any(self: Box) -> Box; + fn is_background(&self) -> bool; +} + +pub enum MessagePriority { + Foreground, + Background, } impl AnyTypedEnvelope for TypedEnvelope { @@ -53,10 +60,14 @@ impl AnyTypedEnvelope for TypedEnvelope { fn into_any(self: Box) -> Box { self } + + fn is_background(&self) -> bool { + matches!(T::PRIORITY, MessagePriority::Background) + } } macro_rules! messages { - ($($name:ident),* $(,)?) => { + ($(($name:ident, $priority:ident)),* $(,)?) => { pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option> { match envelope.payload { $(Some(envelope::Payload::$name(payload)) => { @@ -74,6 +85,7 @@ macro_rules! messages { $( impl EnvelopedMessage for $name { const NAME: &'static str = std::stringify!($name); + const PRIORITY: MessagePriority = MessagePriority::$priority; fn into_envelope( self, @@ -120,59 +132,60 @@ macro_rules! entity_messages { } messages!( - Ack, - AddProjectCollaborator, - ApplyCodeAction, - ApplyCodeActionResponse, - ApplyCompletionAdditionalEdits, - ApplyCompletionAdditionalEditsResponse, - BufferReloaded, - BufferSaved, - ChannelMessageSent, - CloseBuffer, - DiskBasedDiagnosticsUpdated, - DiskBasedDiagnosticsUpdating, - Error, - FormatBuffers, - FormatBuffersResponse, - GetChannelMessages, - GetChannelMessagesResponse, - GetChannels, - GetChannelsResponse, - GetCodeActions, - GetCodeActionsResponse, - GetCompletions, - GetCompletionsResponse, - GetDefinition, - GetDefinitionResponse, - GetUsers, - GetUsersResponse, - JoinChannel, - JoinChannelResponse, - JoinProject, - JoinProjectResponse, - LeaveChannel, - LeaveProject, - OpenBuffer, - OpenBufferResponse, - RegisterProjectResponse, - Ping, - RegisterProject, - RegisterWorktree, - RemoveProjectCollaborator, - SaveBuffer, - SendChannelMessage, - SendChannelMessageResponse, - ShareProject, - ShareWorktree, - UnregisterProject, - UnregisterWorktree, - UnshareProject, - UpdateBuffer, - UpdateBufferFile, - UpdateContacts, - UpdateDiagnosticSummary, - UpdateWorktree, + (Ack, Foreground), + (AddProjectCollaborator, Foreground), + (ApplyCodeAction, Foreground), + (ApplyCodeActionResponse, Foreground), + (ApplyCompletionAdditionalEdits, Foreground), + (ApplyCompletionAdditionalEditsResponse, Foreground), + (BufferReloaded, Foreground), + (BufferSaved, Foreground), + (ChannelMessageSent, Foreground), + (CloseBuffer, Foreground), + (DiskBasedDiagnosticsUpdated, Background), + (DiskBasedDiagnosticsUpdating, Background), + (Error, Foreground), + (FormatBuffers, Foreground), + (FormatBuffersResponse, Foreground), + (GetChannelMessages, Foreground), + (GetChannelMessagesResponse, Foreground), + (GetChannels, Foreground), + (GetChannelsResponse, Foreground), + (GetCodeActions, Background), + (GetCodeActionsResponse, Foreground), + (GetCompletions, Background), + (GetCompletionsResponse, Foreground), + (GetDefinition, Foreground), + (GetDefinitionResponse, Foreground), + (GetUsers, Foreground), + (GetUsersResponse, Foreground), + (JoinChannel, Foreground), + (JoinChannelResponse, Foreground), + (JoinProject, Foreground), + (JoinProjectResponse, Foreground), + (LeaveChannel, Foreground), + (LeaveProject, Foreground), + (OpenBuffer, Foreground), + (OpenBufferResponse, Foreground), + (RegisterProjectResponse, Foreground), + (Ping, Foreground), + (RegisterProject, Foreground), + (RegisterWorktree, Foreground), + (RemoveProjectCollaborator, Foreground), + (SaveBuffer, Foreground), + (SendChannelMessage, Foreground), + (SendChannelMessageResponse, Foreground), + (ShareProject, Foreground), + (ShareWorktree, Foreground), + (Test, Foreground), + (UnregisterProject, Foreground), + (UnregisterWorktree, Foreground), + (UnshareProject, Foreground), + (UpdateBuffer, Foreground), + (UpdateBufferFile, Foreground), + (UpdateContacts, Foreground), + (UpdateDiagnosticSummary, Foreground), + (UpdateWorktree, Foreground), ); request_messages!( @@ -198,7 +211,9 @@ request_messages!( (SendChannelMessage, SendChannelMessageResponse), (ShareProject, Ack), (ShareWorktree, Ack), + (Test, Test), (UpdateBuffer, Ack), + (UpdateWorktree, Ack), ); entity_messages!( @@ -256,12 +271,19 @@ where { /// Write a given protobuf message to the stream. pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> { + #[cfg(any(test, feature = "test-support"))] + const COMPRESSION_LEVEL: i32 = -7; + + #[cfg(not(any(test, feature = "test-support")))] + const COMPRESSION_LEVEL: i32 = 4; + self.encoding_buffer.resize(message.encoded_len(), 0); self.encoding_buffer.clear(); message .encode(&mut self.encoding_buffer) .map_err(|err| io::Error::from(err))?; - let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap(); + let buffer = + zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap(); self.stream.send(WebSocketMessage::Binary(buffer)).await?; Ok(()) } diff --git a/crates/server/Cargo.toml b/crates/server/Cargo.toml index a4357415da4c9014fe5aad6f6ba19181b178c587..38aeac8a6b0a1d82ad9b2f3bf0e8a8f04cc74ceb 100644 --- a/crates/server/Cargo.toml +++ b/crates/server/Cargo.toml @@ -15,7 +15,6 @@ required-features = ["seed-support"] [dependencies] collections = { path = "../collections" } rpc = { path = "../rpc" } - anyhow = "1.0.40" async-std = { version = "1.8.0", features = ["attributes"] } async-trait = "0.1.50" @@ -57,12 +56,12 @@ features = ["runtime-async-std-rustls", "postgres", "time", "uuid"] [dev-dependencies] collections = { path = "../collections", features = ["test-support"] } -gpui = { path = "../gpui" } +gpui = { path = "../gpui", features = ["test-support"] } +rpc = { path = "../rpc", features = ["test-support"] } zed = { path = "../zed", features = ["test-support"] } ctor = "0.1" env_logger = "0.8" util = { path = "../util" } - lazy_static = "1.4" serde_json = { version = "1.0.64", features = ["preserve_order"] } diff --git a/crates/server/src/api.rs b/crates/server/src/api.rs index 0999a28d9027d5829fdea6c9b4a64616dbca7a7a..69b60fe9ec4ae21359e5cdfe932d244b1aea67f6 100644 --- a/crates/server/src/api.rs +++ b/crates/server/src/api.rs @@ -111,7 +111,7 @@ async fn create_access_token(request: Request) -> tide::Result { .get_user_by_github_login(request.param("github_login")?) .await? .ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?; - let access_token = auth::create_access_token(request.db(), user.id).await?; + let access_token = auth::create_access_token(request.db().as_ref(), user.id).await?; #[derive(Deserialize)] struct QueryParams { diff --git a/crates/server/src/auth.rs b/crates/server/src/auth.rs index 1fbd137d1298e9f1c8b79fb2ae6c190cf3c427b1..91136b46d065afbbfbaad7340d312e15b4af0166 100644 --- a/crates/server/src/auth.rs +++ b/crates/server/src/auth.rs @@ -234,7 +234,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { let mut user_id = user.id; if let Some(impersonated_login) = app_sign_in_params.impersonate { log::info!("attempting to impersonate user @{}", impersonated_login); - if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() { + if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() { if user.admin { user_id = request.db().create_user(&impersonated_login, false).await?; log::info!("impersonating user {}", user_id.0); @@ -244,7 +244,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result { } } - let access_token = create_access_token(request.db(), user_id).await?; + let access_token = create_access_token(request.db().as_ref(), user_id).await?; let encrypted_access_token = encrypt_access_token( &access_token, app_sign_in_params.native_app_public_key.clone(), @@ -267,7 +267,7 @@ async fn post_sign_out(mut request: Request) -> tide::Result { const MAX_ACCESS_TOKENS_TO_STORE: usize = 8; -pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result { +pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result { let access_token = zed_auth::random_token(); let access_token_hash = hash_access_token(&access_token).context("failed to hash access token")?; diff --git a/crates/server/src/db.rs b/crates/server/src/db.rs index a48673f30fbeb70ff47f9fbab2b4b58aef23fbba..37e35be5f8ca21bf810cee15147db5bd01d4019c 100644 --- a/crates/server/src/db.rs +++ b/crates/server/src/db.rs @@ -1,11 +1,12 @@ use anyhow::Context; +use anyhow::Result; +pub use async_sqlx_session::PostgresSessionStore as SessionStore; use async_std::task::{block_on, yield_now}; +use async_trait::async_trait; use serde::Serialize; -use sqlx::{types::Uuid, FromRow, Result}; -use time::OffsetDateTime; - -pub use async_sqlx_session::PostgresSessionStore as SessionStore; pub use sqlx::postgres::PgPoolOptions as DbOptions; +use sqlx::{types::Uuid, FromRow}; +use time::OffsetDateTime; macro_rules! test_support { ($self:ident, { $($token:tt)* }) => {{ @@ -21,13 +22,77 @@ macro_rules! test_support { }}; } -#[derive(Clone)] -pub struct Db { +#[async_trait] +pub trait Db: Send + Sync { + async fn create_signup( + &self, + github_login: &str, + email_address: &str, + about: &str, + wants_releases: bool, + wants_updates: bool, + wants_community: bool, + ) -> Result; + async fn get_all_signups(&self) -> Result>; + async fn destroy_signup(&self, id: SignupId) -> Result<()>; + async fn create_user(&self, github_login: &str, admin: bool) -> Result; + async fn get_all_users(&self) -> Result>; + async fn get_user_by_id(&self, id: UserId) -> Result>; + async fn get_users_by_ids(&self, ids: Vec) -> Result>; + async fn get_user_by_github_login(&self, github_login: &str) -> Result>; + async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>; + async fn destroy_user(&self, id: UserId) -> Result<()>; + async fn create_access_token_hash( + &self, + user_id: UserId, + access_token_hash: &str, + max_access_token_count: usize, + ) -> Result<()>; + async fn get_access_token_hashes(&self, user_id: UserId) -> Result>; + #[cfg(any(test, feature = "seed-support"))] + async fn find_org_by_slug(&self, slug: &str) -> Result>; + #[cfg(any(test, feature = "seed-support"))] + async fn create_org(&self, name: &str, slug: &str) -> Result; + #[cfg(any(test, feature = "seed-support"))] + async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>; + #[cfg(any(test, feature = "seed-support"))] + async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result; + #[cfg(any(test, feature = "seed-support"))] + async fn get_org_channels(&self, org_id: OrgId) -> Result>; + async fn get_accessible_channels(&self, user_id: UserId) -> Result>; + async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId) + -> Result; + #[cfg(any(test, feature = "seed-support"))] + async fn add_channel_member( + &self, + channel_id: ChannelId, + user_id: UserId, + is_admin: bool, + ) -> Result<()>; + async fn create_channel_message( + &self, + channel_id: ChannelId, + sender_id: UserId, + body: &str, + timestamp: OffsetDateTime, + nonce: u128, + ) -> Result; + async fn get_channel_messages( + &self, + channel_id: ChannelId, + count: usize, + before_id: Option, + ) -> Result>; + #[cfg(test)] + async fn teardown(&self, name: &str, url: &str); +} + +pub struct PostgresDb { pool: sqlx::PgPool, test_mode: bool, } -impl Db { +impl PostgresDb { pub async fn new(url: &str, max_connections: u32) -> tide::Result { let pool = DbOptions::new() .max_connections(max_connections) @@ -39,10 +104,12 @@ impl Db { test_mode: false, }) } +} +#[async_trait] +impl Db for PostgresDb { // signups - - pub async fn create_signup( + async fn create_signup( &self, github_login: &str, email_address: &str, @@ -64,7 +131,7 @@ impl Db { VALUES ($1, $2, $3, $4, $5, $6) RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(github_login) .bind(email_address) .bind(about) @@ -73,31 +140,31 @@ impl Db { .bind(wants_community) .fetch_one(&self.pool) .await - .map(SignupId) + .map(SignupId)?) }) } - pub async fn get_all_signups(&self) -> Result> { + async fn get_all_signups(&self) -> Result> { test_support!(self, { let query = "SELECT * FROM signups ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.pool).await + Ok(sqlx::query_as(query).fetch_all(&self.pool).await?) }) } - pub async fn destroy_signup(&self, id: SignupId) -> Result<()> { + async fn destroy_signup(&self, id: SignupId) -> Result<()> { test_support!(self, { let query = "DELETE FROM signups WHERE id = $1"; - sqlx::query(query) + Ok(sqlx::query(query) .bind(id.0) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // users - pub async fn create_user(&self, github_login: &str, admin: bool) -> Result { + async fn create_user(&self, github_login: &str, admin: bool) -> Result { test_support!(self, { let query = " INSERT INTO users (github_login, admin) @@ -105,31 +172,28 @@ impl Db { ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(github_login) .bind(admin) .fetch_one(&self.pool) .await - .map(UserId) + .map(UserId)?) }) } - pub async fn get_all_users(&self) -> Result> { + async fn get_all_users(&self) -> Result> { test_support!(self, { let query = "SELECT * FROM users ORDER BY github_login ASC"; - sqlx::query_as(query).fetch_all(&self.pool).await + Ok(sqlx::query_as(query).fetch_all(&self.pool).await?) }) } - pub async fn get_user_by_id(&self, id: UserId) -> Result> { - let users = self.get_users_by_ids([id]).await?; + async fn get_user_by_id(&self, id: UserId) -> Result> { + let users = self.get_users_by_ids(vec![id]).await?; Ok(users.into_iter().next()) } - pub async fn get_users_by_ids( - &self, - ids: impl IntoIterator, - ) -> Result> { + async fn get_users_by_ids(&self, ids: Vec) -> Result> { let ids = ids.into_iter().map(|id| id.0).collect::>(); test_support!(self, { let query = " @@ -138,33 +202,36 @@ impl Db { WHERE users.id = ANY ($1) "; - sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await + Ok(sqlx::query_as(query) + .bind(&ids) + .fetch_all(&self.pool) + .await?) }) } - pub async fn get_user_by_github_login(&self, github_login: &str) -> Result> { + async fn get_user_by_github_login(&self, github_login: &str) -> Result> { test_support!(self, { let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1"; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(github_login) .fetch_optional(&self.pool) - .await + .await?) }) } - pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { + async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> { test_support!(self, { let query = "UPDATE users SET admin = $1 WHERE id = $2"; - sqlx::query(query) + Ok(sqlx::query(query) .bind(is_admin) .bind(id.0) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } - pub async fn destroy_user(&self, id: UserId) -> Result<()> { + async fn destroy_user(&self, id: UserId) -> Result<()> { test_support!(self, { let query = "DELETE FROM access_tokens WHERE user_id = $1;"; sqlx::query(query) @@ -173,17 +240,17 @@ impl Db { .await .map(drop)?; let query = "DELETE FROM users WHERE id = $1;"; - sqlx::query(query) + Ok(sqlx::query(query) .bind(id.0) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // access tokens - pub async fn create_access_token_hash( + async fn create_access_token_hash( &self, user_id: UserId, access_token_hash: &str, @@ -216,11 +283,11 @@ impl Db { .bind(max_access_token_count as u32) .execute(&mut tx) .await?; - tx.commit().await + Ok(tx.commit().await?) }) } - pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { + async fn get_access_token_hashes(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT hash @@ -228,10 +295,10 @@ impl Db { WHERE user_id = $1 ORDER BY id DESC "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(user_id.0) .fetch_all(&self.pool) - .await + .await?) }) } @@ -239,82 +306,77 @@ impl Db { #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] - pub async fn find_org_by_slug(&self, slug: &str) -> Result> { + async fn find_org_by_slug(&self, slug: &str) -> Result> { test_support!(self, { let query = " SELECT * FROM orgs WHERE slug = $1 "; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(slug) .fetch_optional(&self.pool) - .await + .await?) }) } #[cfg(any(test, feature = "seed-support"))] - pub async fn create_org(&self, name: &str, slug: &str) -> Result { + async fn create_org(&self, name: &str, slug: &str) -> Result { test_support!(self, { let query = " INSERT INTO orgs (name, slug) VALUES ($1, $2) RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(name) .bind(slug) .fetch_one(&self.pool) .await - .map(OrgId) + .map(OrgId)?) }) } #[cfg(any(test, feature = "seed-support"))] - pub async fn add_org_member( - &self, - org_id: OrgId, - user_id: UserId, - is_admin: bool, - ) -> Result<()> { + async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> { test_support!(self, { let query = " INSERT INTO org_memberships (org_id, user_id, admin) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING "; - sqlx::query(query) + Ok(sqlx::query(query) .bind(org_id.0) .bind(user_id.0) .bind(is_admin) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // channels #[cfg(any(test, feature = "seed-support"))] - pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { + async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { test_support!(self, { let query = " INSERT INTO channels (owner_id, owner_is_user, name) VALUES ($1, false, $2) RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(org_id.0) .bind(name) .fetch_one(&self.pool) .await - .map(ChannelId) + .map(ChannelId)?) }) } #[allow(unused)] // Help rust-analyzer #[cfg(any(test, feature = "seed-support"))] - pub async fn get_org_channels(&self, org_id: OrgId) -> Result> { + async fn get_org_channels(&self, org_id: OrgId) -> Result> { test_support!(self, { let query = " SELECT * @@ -323,32 +385,32 @@ impl Db { channels.owner_is_user = false AND channels.owner_id = $1 "; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(org_id.0) .fetch_all(&self.pool) - .await + .await?) }) } - pub async fn get_accessible_channels(&self, user_id: UserId) -> Result> { + async fn get_accessible_channels(&self, user_id: UserId) -> Result> { test_support!(self, { let query = " SELECT - channels.id, channels.name + channels.* FROM channel_memberships, channels WHERE channel_memberships.user_id = $1 AND channel_memberships.channel_id = channels.id "; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(user_id.0) .fetch_all(&self.pool) - .await + .await?) }) } - pub async fn can_user_access_channel( + async fn can_user_access_channel( &self, user_id: UserId, channel_id: ChannelId, @@ -360,17 +422,17 @@ impl Db { WHERE user_id = $1 AND channel_id = $2 LIMIT 1 "; - sqlx::query_scalar::<_, i32>(query) + Ok(sqlx::query_scalar::<_, i32>(query) .bind(user_id.0) .bind(channel_id.0) .fetch_optional(&self.pool) .await - .map(|e| e.is_some()) + .map(|e| e.is_some())?) }) } #[cfg(any(test, feature = "seed-support"))] - pub async fn add_channel_member( + async fn add_channel_member( &self, channel_id: ChannelId, user_id: UserId, @@ -382,19 +444,19 @@ impl Db { VALUES ($1, $2, $3) ON CONFLICT DO NOTHING "; - sqlx::query(query) + Ok(sqlx::query(query) .bind(channel_id.0) .bind(user_id.0) .bind(is_admin) .execute(&self.pool) .await - .map(drop) + .map(drop)?) }) } // messages - pub async fn create_channel_message( + async fn create_channel_message( &self, channel_id: ChannelId, sender_id: UserId, @@ -409,7 +471,7 @@ impl Db { ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce RETURNING id "; - sqlx::query_scalar(query) + Ok(sqlx::query_scalar(query) .bind(channel_id.0) .bind(sender_id.0) .bind(body) @@ -417,11 +479,11 @@ impl Db { .bind(Uuid::from_u128(nonce)) .fetch_one(&self.pool) .await - .map(MessageId) + .map(MessageId)?) }) } - pub async fn get_channel_messages( + async fn get_channel_messages( &self, channel_id: ChannelId, count: usize, @@ -431,7 +493,7 @@ impl Db { let query = r#" SELECT * FROM ( SELECT - id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce + id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce FROM channel_messages WHERE @@ -442,12 +504,34 @@ impl Db { ) as recent_messages ORDER BY id ASC "#; - sqlx::query_as(query) + Ok(sqlx::query_as(query) .bind(channel_id.0) .bind(before_id.unwrap_or(MessageId::MAX)) .bind(count as i64) .fetch_all(&self.pool) + .await?) + }) + } + + #[cfg(test)] + async fn teardown(&self, name: &str, url: &str) { + use util::ResultExt; + + test_support!(self, { + let query = " + SELECT pg_terminate_backend(pg_stat_activity.pid) + FROM pg_stat_activity + WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); + "; + sqlx::query(query) + .bind(name) + .execute(&self.pool) .await + .log_err(); + self.pool.close().await; + ::drop_database(url) + .await + .log_err(); }) } } @@ -479,7 +563,7 @@ macro_rules! id_type { } id_type!(UserId); -#[derive(Debug, FromRow, Serialize, PartialEq)] +#[derive(Clone, Debug, FromRow, Serialize, PartialEq)] pub struct User { pub id: UserId, pub github_login: String, @@ -507,16 +591,19 @@ pub struct Signup { } id_type!(ChannelId); -#[derive(Debug, FromRow, Serialize)] +#[derive(Clone, Debug, FromRow, Serialize)] pub struct Channel { pub id: ChannelId, pub name: String, + pub owner_id: i32, + pub owner_is_user: bool, } id_type!(MessageId); -#[derive(Debug, FromRow)] +#[derive(Clone, Debug, FromRow)] pub struct ChannelMessage { pub id: MessageId, + pub channel_id: ChannelId, pub sender_id: UserId, pub body: String, pub sent_at: OffsetDateTime, @@ -526,6 +613,9 @@ pub struct ChannelMessage { #[cfg(test)] pub mod tests { use super::*; + use anyhow::anyhow; + use collections::BTreeMap; + use gpui::{executor::Background, TestAppContext}; use lazy_static::lazy_static; use parking_lot::Mutex; use rand::prelude::*; @@ -533,217 +623,119 @@ pub mod tests { migrate::{MigrateDatabase, Migrator}, Postgres, }; - use std::{ - mem, - path::Path, - sync::atomic::{AtomicUsize, Ordering::SeqCst}, - }; - use util::ResultExt as _; - - pub struct TestDb { - pub db: Option, - pub name: String, - pub url: String, - } + use std::{path::Path, sync::Arc}; + use util::post_inc; - lazy_static! { - static ref DB_POOL: Mutex> = Default::default(); - static ref DB_COUNT: AtomicUsize = Default::default(); - } + #[gpui::test] + async fn test_get_users_by_ids(cx: TestAppContext) { + for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] { + let db = test_db.db(); - impl TestDb { - pub fn new() -> Self { - DB_COUNT.fetch_add(1, SeqCst); - let mut pool = DB_POOL.lock(); - if let Some(db) = pool.pop() { - db.truncate(); - db - } else { - let mut rng = StdRng::from_entropy(); - let name = format!("zed-test-{}", rng.gen::()); - let url = format!("postgres://postgres@localhost/{}", name); - let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); - let db = block_on(async { - Postgres::create_database(&url) - .await - .expect("failed to create test db"); - let mut db = Db::new(&url, 5).await.unwrap(); - db.test_mode = true; - let migrator = Migrator::new(migrations_path).await.unwrap(); - migrator.run(&db.pool).await.unwrap(); - db - }); - - Self { - db: Some(db), - name, - url, - } - } - } + let user = db.create_user("user", false).await.unwrap(); + let friend1 = db.create_user("friend-1", false).await.unwrap(); + let friend2 = db.create_user("friend-2", false).await.unwrap(); + let friend3 = db.create_user("friend-3", false).await.unwrap(); - pub fn db(&self) -> &Db { - self.db.as_ref().unwrap() - } - - fn truncate(&self) { - block_on(async { - let query = " - SELECT tablename FROM pg_tables - WHERE schemaname = 'public'; - "; - let table_names = sqlx::query_scalar::<_, String>(query) - .fetch_all(&self.db().pool) + assert_eq!( + db.get_users_by_ids(vec![user, friend1, friend2, friend3]) .await - .unwrap(); - sqlx::query(&format!( - "TRUNCATE TABLE {} RESTART IDENTITY", - table_names.join(", ") - )) - .execute(&self.db().pool) - .await - .unwrap(); - }) - } - - async fn teardown(mut self) -> Result<()> { - let db = self.db.take().unwrap(); - let query = " - SELECT pg_terminate_backend(pg_stat_activity.pid) - FROM pg_stat_activity - WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid(); - "; - sqlx::query(query) - .bind(&self.name) - .execute(&db.pool) - .await?; - db.pool.close().await; - Postgres::drop_database(&self.url).await?; - Ok(()) - } - } - - impl Drop for TestDb { - fn drop(&mut self) { - if let Some(db) = self.db.take() { - DB_POOL.lock().push(TestDb { - db: Some(db), - name: mem::take(&mut self.name), - url: mem::take(&mut self.url), - }); - if DB_COUNT.fetch_sub(1, SeqCst) == 1 { - block_on(async move { - let mut pool = DB_POOL.lock(); - for db in pool.drain(..) { - db.teardown().await.log_err(); - } - }); - } - } + .unwrap(), + vec![ + User { + id: user, + github_login: "user".to_string(), + admin: false, + }, + User { + id: friend1, + github_login: "friend-1".to_string(), + admin: false, + }, + User { + id: friend2, + github_login: "friend-2".to_string(), + admin: false, + }, + User { + id: friend3, + github_login: "friend-3".to_string(), + admin: false, + } + ] + ); } } #[gpui::test] - async fn test_get_users_by_ids() { - let test_db = TestDb::new(); - let db = test_db.db(); - - let user = db.create_user("user", false).await.unwrap(); - let friend1 = db.create_user("friend-1", false).await.unwrap(); - let friend2 = db.create_user("friend-2", false).await.unwrap(); - let friend3 = db.create_user("friend-3", false).await.unwrap(); - - assert_eq!( - db.get_users_by_ids([user, friend1, friend2, friend3]) + async fn test_recent_channel_messages(cx: TestAppContext) { + for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] { + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); + for i in 0..10 { + db.create_channel_message( + channel, + user, + &i.to_string(), + OffsetDateTime::now_utc(), + i, + ) .await - .unwrap(), - vec![ - User { - id: user, - github_login: "user".to_string(), - admin: false, - }, - User { - id: friend1, - github_login: "friend-1".to_string(), - admin: false, - }, - User { - id: friend2, - github_login: "friend-2".to_string(), - admin: false, - }, - User { - id: friend3, - github_login: "friend-3".to_string(), - admin: false, - } - ] - ); - } + .unwrap(); + } - #[gpui::test] - async fn test_recent_channel_messages() { - let test_db = TestDb::new(); - let db = test_db.db(); - let user = db.create_user("user", false).await.unwrap(); - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - for i in 0..10 { - db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i) + let messages = db.get_channel_messages(channel, 5, None).await.unwrap(); + assert_eq!( + messages.iter().map(|m| &m.body).collect::>(), + ["5", "6", "7", "8", "9"] + ); + + let prev_messages = db + .get_channel_messages(channel, 4, Some(messages[0].id)) .await .unwrap(); + assert_eq!( + prev_messages.iter().map(|m| &m.body).collect::>(), + ["1", "2", "3", "4"] + ); } - - let messages = db.get_channel_messages(channel, 5, None).await.unwrap(); - assert_eq!( - messages.iter().map(|m| &m.body).collect::>(), - ["5", "6", "7", "8", "9"] - ); - - let prev_messages = db - .get_channel_messages(channel, 4, Some(messages[0].id)) - .await - .unwrap(); - assert_eq!( - prev_messages.iter().map(|m| &m.body).collect::>(), - ["1", "2", "3", "4"] - ); } #[gpui::test] - async fn test_channel_message_nonces() { - let test_db = TestDb::new(); - let db = test_db.db(); - let user = db.create_user("user", false).await.unwrap(); - let org = db.create_org("org", "org").await.unwrap(); - let channel = db.create_org_channel(org, "channel").await.unwrap(); - - let msg1_id = db - .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg2_id = db - .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); - let msg3_id = db - .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) - .await - .unwrap(); - let msg4_id = db - .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) - .await - .unwrap(); + async fn test_channel_message_nonces(cx: TestAppContext) { + for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] { + let db = test_db.db(); + let user = db.create_user("user", false).await.unwrap(); + let org = db.create_org("org", "org").await.unwrap(); + let channel = db.create_org_channel(org, "channel").await.unwrap(); + + let msg1_id = db + .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg2_id = db + .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); + let msg3_id = db + .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1) + .await + .unwrap(); + let msg4_id = db + .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2) + .await + .unwrap(); - assert_ne!(msg1_id, msg2_id); - assert_eq!(msg1_id, msg3_id); - assert_eq!(msg2_id, msg4_id); + assert_ne!(msg1_id, msg2_id); + assert_eq!(msg1_id, msg3_id); + assert_eq!(msg2_id, msg4_id); + } } #[gpui::test] async fn test_create_access_tokens() { - let test_db = TestDb::new(); + let test_db = TestDb::postgres(); let db = test_db.db(); let user = db.create_user("the-user", false).await.unwrap(); @@ -772,4 +764,359 @@ pub mod tests { &["h5".to_string(), "h4".to_string(), "h3".to_string()] ); } + + pub struct TestDb { + pub db: Option>, + pub name: String, + pub url: String, + } + + impl TestDb { + pub fn postgres() -> Self { + lazy_static! { + static ref LOCK: Mutex<()> = Mutex::new(()); + } + + let _guard = LOCK.lock(); + let mut rng = StdRng::from_entropy(); + let name = format!("zed-test-{}", rng.gen::()); + let url = format!("postgres://postgres@localhost/{}", name); + let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations")); + let db = block_on(async { + Postgres::create_database(&url) + .await + .expect("failed to create test db"); + let mut db = PostgresDb::new(&url, 5).await.unwrap(); + db.test_mode = true; + let migrator = Migrator::new(migrations_path).await.unwrap(); + migrator.run(&db.pool).await.unwrap(); + db + }); + Self { + db: Some(Arc::new(db)), + name, + url, + } + } + + pub fn fake(background: Arc) -> Self { + Self { + db: Some(Arc::new(FakeDb::new(background))), + name: "fake".to_string(), + url: "fake".to_string(), + } + } + + pub fn db(&self) -> &Arc { + self.db.as_ref().unwrap() + } + } + + impl Drop for TestDb { + fn drop(&mut self) { + if let Some(db) = self.db.take() { + block_on(db.teardown(&self.name, &self.url)); + } + } + } + + pub struct FakeDb { + background: Arc, + users: Mutex>, + next_user_id: Mutex, + orgs: Mutex>, + next_org_id: Mutex, + org_memberships: Mutex>, + channels: Mutex>, + next_channel_id: Mutex, + channel_memberships: Mutex>, + channel_messages: Mutex>, + next_channel_message_id: Mutex, + } + + impl FakeDb { + pub fn new(background: Arc) -> Self { + Self { + background, + users: Default::default(), + next_user_id: Mutex::new(1), + orgs: Default::default(), + next_org_id: Mutex::new(1), + org_memberships: Default::default(), + channels: Default::default(), + next_channel_id: Mutex::new(1), + channel_memberships: Default::default(), + channel_messages: Default::default(), + next_channel_message_id: Mutex::new(1), + } + } + } + + #[async_trait] + impl Db for FakeDb { + async fn create_signup( + &self, + _github_login: &str, + _email_address: &str, + _about: &str, + _wants_releases: bool, + _wants_updates: bool, + _wants_community: bool, + ) -> Result { + unimplemented!() + } + + async fn get_all_signups(&self) -> Result> { + unimplemented!() + } + + async fn destroy_signup(&self, _id: SignupId) -> Result<()> { + unimplemented!() + } + + async fn create_user(&self, github_login: &str, admin: bool) -> Result { + self.background.simulate_random_delay().await; + + let mut users = self.users.lock(); + if let Some(user) = users + .values() + .find(|user| user.github_login == github_login) + { + Ok(user.id) + } else { + let user_id = UserId(post_inc(&mut *self.next_user_id.lock())); + users.insert( + user_id, + User { + id: user_id, + github_login: github_login.to_string(), + admin, + }, + ); + Ok(user_id) + } + } + + async fn get_all_users(&self) -> Result> { + unimplemented!() + } + + async fn get_user_by_id(&self, id: UserId) -> Result> { + Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next()) + } + + async fn get_users_by_ids(&self, ids: Vec) -> Result> { + self.background.simulate_random_delay().await; + let users = self.users.lock(); + Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect()) + } + + async fn get_user_by_github_login(&self, _github_login: &str) -> Result> { + unimplemented!() + } + + async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> { + unimplemented!() + } + + async fn destroy_user(&self, _id: UserId) -> Result<()> { + unimplemented!() + } + + async fn create_access_token_hash( + &self, + _user_id: UserId, + _access_token_hash: &str, + _max_access_token_count: usize, + ) -> Result<()> { + unimplemented!() + } + + async fn get_access_token_hashes(&self, _user_id: UserId) -> Result> { + unimplemented!() + } + + async fn find_org_by_slug(&self, _slug: &str) -> Result> { + unimplemented!() + } + + async fn create_org(&self, name: &str, slug: &str) -> Result { + self.background.simulate_random_delay().await; + let mut orgs = self.orgs.lock(); + if orgs.values().any(|org| org.slug == slug) { + Err(anyhow!("org already exists")) + } else { + let org_id = OrgId(post_inc(&mut *self.next_org_id.lock())); + orgs.insert( + org_id, + Org { + id: org_id, + name: name.to_string(), + slug: slug.to_string(), + }, + ); + Ok(org_id) + } + } + + async fn add_org_member( + &self, + org_id: OrgId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { + self.background.simulate_random_delay().await; + if !self.orgs.lock().contains_key(&org_id) { + return Err(anyhow!("org does not exist")); + } + if !self.users.lock().contains_key(&user_id) { + return Err(anyhow!("user does not exist")); + } + + self.org_memberships + .lock() + .entry((org_id, user_id)) + .or_insert(is_admin); + Ok(()) + } + + async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result { + self.background.simulate_random_delay().await; + if !self.orgs.lock().contains_key(&org_id) { + return Err(anyhow!("org does not exist")); + } + + let mut channels = self.channels.lock(); + let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock())); + channels.insert( + channel_id, + Channel { + id: channel_id, + name: name.to_string(), + owner_id: org_id.0, + owner_is_user: false, + }, + ); + Ok(channel_id) + } + + async fn get_org_channels(&self, org_id: OrgId) -> Result> { + self.background.simulate_random_delay().await; + Ok(self + .channels + .lock() + .values() + .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0) + .cloned() + .collect()) + } + + async fn get_accessible_channels(&self, user_id: UserId) -> Result> { + self.background.simulate_random_delay().await; + let channels = self.channels.lock(); + let memberships = self.channel_memberships.lock(); + Ok(channels + .values() + .filter(|channel| memberships.contains_key(&(channel.id, user_id))) + .cloned() + .collect()) + } + + async fn can_user_access_channel( + &self, + user_id: UserId, + channel_id: ChannelId, + ) -> Result { + self.background.simulate_random_delay().await; + Ok(self + .channel_memberships + .lock() + .contains_key(&(channel_id, user_id))) + } + + async fn add_channel_member( + &self, + channel_id: ChannelId, + user_id: UserId, + is_admin: bool, + ) -> Result<()> { + self.background.simulate_random_delay().await; + if !self.channels.lock().contains_key(&channel_id) { + return Err(anyhow!("channel does not exist")); + } + if !self.users.lock().contains_key(&user_id) { + return Err(anyhow!("user does not exist")); + } + + self.channel_memberships + .lock() + .entry((channel_id, user_id)) + .or_insert(is_admin); + Ok(()) + } + + async fn create_channel_message( + &self, + channel_id: ChannelId, + sender_id: UserId, + body: &str, + timestamp: OffsetDateTime, + nonce: u128, + ) -> Result { + self.background.simulate_random_delay().await; + if !self.channels.lock().contains_key(&channel_id) { + return Err(anyhow!("channel does not exist")); + } + if !self.users.lock().contains_key(&sender_id) { + return Err(anyhow!("user does not exist")); + } + + let mut messages = self.channel_messages.lock(); + if let Some(message) = messages + .values() + .find(|message| message.nonce.as_u128() == nonce) + { + Ok(message.id) + } else { + let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock())); + messages.insert( + message_id, + ChannelMessage { + id: message_id, + channel_id, + sender_id, + body: body.to_string(), + sent_at: timestamp, + nonce: Uuid::from_u128(nonce), + }, + ); + Ok(message_id) + } + } + + async fn get_channel_messages( + &self, + channel_id: ChannelId, + count: usize, + before_id: Option, + ) -> Result> { + let mut messages = self + .channel_messages + .lock() + .values() + .rev() + .filter(|message| { + message.channel_id == channel_id + && message.id < before_id.unwrap_or(MessageId::MAX) + }) + .take(count) + .cloned() + .collect::>(); + dbg!(count, before_id, &messages); + messages.sort_unstable_by_key(|message| message.id); + Ok(messages) + } + + async fn teardown(&self, _name: &str, _url: &str) {} + } } diff --git a/crates/server/src/main.rs b/crates/server/src/main.rs index 3301fb24a90f4a87908421caade36dd2ccf39eb7..47c8c82190bfbab775b981ebdfa39875a83f764f 100644 --- a/crates/server/src/main.rs +++ b/crates/server/src/main.rs @@ -20,7 +20,7 @@ use anyhow::Result; use async_std::net::TcpListener; use async_trait::async_trait; use auth::RequestExt as _; -use db::Db; +use db::{Db, PostgresDb}; use handlebars::{Handlebars, TemplateRenderError}; use parking_lot::RwLock; use rust_embed::RustEmbed; @@ -49,7 +49,7 @@ pub struct Config { } pub struct AppState { - db: Db, + db: Arc, handlebars: RwLock>, auth_client: auth::Client, github_client: Arc, @@ -59,7 +59,7 @@ pub struct AppState { impl AppState { async fn new(config: Config) -> tide::Result> { - let db = Db::new(&config.database_url, 5).await?; + let db = PostgresDb::new(&config.database_url, 5).await?; let github_client = github::AppClient::new(config.github_app_id, config.github_private_key.clone()); let repo_client = github_client @@ -68,7 +68,7 @@ impl AppState { .context("failed to initialize github client")?; let this = Self { - db, + db: Arc::new(db), handlebars: Default::default(), auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret), github_client, @@ -112,7 +112,7 @@ impl AppState { #[async_trait] trait RequestExt { async fn layout_data(&mut self) -> tide::Result>; - fn db(&self) -> &Db; + fn db(&self) -> &Arc; } #[async_trait] @@ -126,7 +126,7 @@ impl RequestExt for Request { Ok(self.ext::>().unwrap().clone()) } - fn db(&self) -> &Db { + fn db(&self) -> &Arc { &self.state().db } } diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index c1d36ef3c6075280291f1d3cbd84fed6c458672d..adb0592df59367f18f790ad208da3fbb81788231 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -9,9 +9,8 @@ use anyhow::anyhow; use async_std::task; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use collections::{HashMap, HashSet}; -use futures::{future::BoxFuture, FutureExt, StreamExt}; +use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt}; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use postage::{mpsc, prelude::Sink as _}; use rpc::{ proto::{self, AnyTypedEnvelope, EnvelopedMessage, RequestMessage}, Connection, ConnectionId, Peer, TypedEnvelope, @@ -38,9 +37,15 @@ pub struct Server { store: RwLock, app_state: Arc, handlers: HashMap, - notifications: Option>, + notifications: Option>, } +pub trait Executor { + fn spawn_detached>(&self, future: F); +} + +pub struct RealExecutor; + const MESSAGE_COUNT_PER_PAGE: usize = 100; const MAX_MESSAGE_LEN: usize = 1024; @@ -48,7 +53,7 @@ impl Server { pub fn new( app_state: Arc, peer: Arc, - notifications: Option>, + notifications: Option>, ) -> Arc { let mut server = Self { peer, @@ -69,7 +74,7 @@ impl Server { .add_request_handler(Server::register_worktree) .add_message_handler(Server::unregister_worktree) .add_request_handler(Server::share_worktree) - .add_message_handler(Server::update_worktree) + .add_request_handler(Server::update_worktree) .add_message_handler(Server::update_diagnostic_summary) .add_message_handler(Server::disk_based_diagnostics_updating) .add_message_handler(Server::disk_based_diagnostics_updated) @@ -144,12 +149,13 @@ impl Server { }) } - pub fn handle_connection( + pub fn handle_connection( self: &Arc, connection: Connection, addr: String, user_id: UserId, - mut send_connection_id: Option>, + mut send_connection_id: Option>, + executor: E, ) -> impl Future { let mut this = self.clone(); async move { @@ -183,14 +189,23 @@ impl Server { let type_name = message.payload_type_name(); log::info!("rpc message received. connection:{}, type:{}", connection_id, type_name); if let Some(handler) = this.handlers.get(&message.payload_type_id()) { - if let Err(err) = (handler)(this.clone(), message).await { - log::error!("rpc message error. connection:{}, type:{}, error:{:?}", connection_id, type_name, err); + let notifications = this.notifications.clone(); + let is_background = message.is_background(); + let handle_message = (handler)(this.clone(), message); + let handle_message = async move { + if let Err(err) = handle_message.await { + log::error!("rpc message error. connection:{}, type:{}, error:{:?}", connection_id, type_name, err); + } else { + log::info!("rpc message handled. connection:{}, type:{}, duration:{:?}", connection_id, type_name, start_time.elapsed()); + } + if let Some(mut notifications) = notifications { + let _ = notifications.send(()).await; + } + }; + if is_background { + executor.spawn_detached(handle_message); } else { - log::info!("rpc message handled. connection:{}, type:{}, duration:{:?}", connection_id, type_name, start_time.elapsed()); - } - - if let Some(mut notifications) = this.notifications.clone() { - let _ = notifications.send(()).await; + handle_message.await; } } else { log::warn!("unhandled message: {}", type_name); @@ -329,6 +344,7 @@ impl Server { .cloned() .collect(), weak: worktree.weak, + next_update_id: share.next_update_id as u64, }) }) .collect(); @@ -467,6 +483,7 @@ impl Server { request.sender_id, entries, diagnostic_summaries, + worktree.next_update_id, )?; broadcast( @@ -485,11 +502,12 @@ impl Server { async fn update_worktree( mut self: Arc, request: TypedEnvelope, - ) -> tide::Result<()> { + ) -> tide::Result { let connection_ids = self.state_mut().update_worktree( request.sender_id, request.payload.project_id, request.payload.worktree_id, + request.payload.id, &request.payload.removed_entries, &request.payload.updated_entries, )?; @@ -499,7 +517,7 @@ impl Server { .forward_send(request.sender_id, connection_id, request.payload.clone()) })?; - Ok(()) + Ok(proto::Ack {}) } async fn update_diagnostic_summary( @@ -767,7 +785,12 @@ impl Server { self: Arc, request: TypedEnvelope, ) -> tide::Result { - let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto); + let user_ids = request + .payload + .user_ids + .into_iter() + .map(UserId::from_proto) + .collect(); let users = self .app_state .db @@ -966,6 +989,12 @@ impl Server { } } +impl Executor for RealExecutor { + fn spawn_detached>(&self, future: F) { + task::spawn(future); + } +} + fn broadcast( sender_id: ConnectionId, receiver_ids: Vec, @@ -1032,6 +1061,7 @@ pub fn add_routes(app: &mut tide::Server>, rpc: &Arc) { addr, user_id, None, + RealExecutor, ) .await; } @@ -1066,15 +1096,17 @@ mod tests { github, AppState, Config, }; use ::rpc::Peer; - use async_std::task; + use collections::BTreeMap; use gpui::{executor, ModelHandle, TestAppContext}; use parking_lot::Mutex; - use postage::{mpsc, watch}; + use postage::{sink::Sink, watch}; use rand::prelude::*; use rpc::PeerId; use serde_json::json; use sqlx::types::time::OffsetDateTime; use std::{ + cell::{Cell, RefCell}, + env, ops::Deref, path::Path, rc::Rc, @@ -1099,7 +1131,7 @@ mod tests { LanguageConfig, LanguageRegistry, LanguageServerConfig, Point, }, lsp, - project::{worktree::WorktreeHandle, DiagnosticSummary, Project, ProjectPath}, + project::{DiagnosticSummary, Project, ProjectPath}, workspace::{Workspace, WorkspaceParams}, }; @@ -1119,7 +1151,7 @@ mod tests { cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1257,7 +1289,7 @@ mod tests { cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1358,7 +1390,7 @@ mod tests { cx_a.foreground().forbid_parking(); // Connect to a server as 3 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; let client_c = server.create_client(&mut cx_c, "user_c").await; @@ -1470,11 +1502,6 @@ mod tests { buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty())); buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await; - // Ensure worktree observes a/file1's change event *before* the rename occurs, otherwise - // when interpreting the change event it will mistakenly think that the file has been - // deleted (because its path has changed) and will subsequently fail to detect the rename. - worktree_a.flush_fs_events(&cx_a).await; - // Make changes on host's file system, see those changes on guest worktrees. fs.rename( "/a/file1".as_ref(), @@ -1483,6 +1510,7 @@ mod tests { ) .await .unwrap(); + fs.rename("/a/file2".as_ref(), "/a/file3".as_ref(), Default::default()) .await .unwrap(); @@ -1540,7 +1568,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1590,14 +1618,12 @@ mod tests { ) .await .unwrap(); - let worktree_b = project_b.update(&mut cx_b, |p, cx| p.worktrees(cx).next().unwrap()); // Open a buffer as client B let buffer_b = project_b .update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)) .await .unwrap(); - let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime()); buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx)); buffer_b.read_with(&cx_b, |buf, _| { @@ -1609,13 +1635,10 @@ mod tests { .update(&mut cx_b, |buf, cx| buf.save(cx)) .await .unwrap(); - worktree_b - .condition(&cx_b, |_, cx| { - buffer_b.read(cx).file().unwrap().mtime() != mtime - }) + buffer_b + .condition(&cx_b, |buffer_b, _| !buffer_b.is_dirty()) .await; buffer_b.read_with(&cx_b, |buf, _| { - assert!(!buf.is_dirty()); assert!(!buf.has_conflict()); }); @@ -1633,7 +1656,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1718,7 +1741,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1778,10 +1801,12 @@ mod tests { let buffer_b = cx_b .background() .spawn(project_b.update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx))); - task::yield_now().await; // Edit the buffer as client A while client B is still opening it. - buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx)); + cx_b.background().simulate_random_delay().await; + buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "X", cx)); + cx_b.background().simulate_random_delay().await; + buffer_a.update(&mut cx_a, |buf, cx| buf.edit([1..1], "Y", cx)); let text = buffer_a.read_with(&cx_a, |buf, _| buf.text()); let buffer_b = buffer_b.await.unwrap(); @@ -1798,7 +1823,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1873,7 +1898,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -1947,8 +1972,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Set up a fake language server. - let (language_server_config, mut fake_language_server) = - LanguageServerConfig::fake(&cx_a).await; + let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake(); Arc::get_mut(&mut lang_registry) .unwrap() .add(Arc::new(Language::new( @@ -1962,7 +1986,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2017,6 +2041,7 @@ mod tests { .unwrap(); // Simulate a language server reporting errors for a file. + let mut fake_language_server = fake_language_servers.next().await.unwrap(); fake_language_server .notify::(lsp::PublishDiagnosticsParams { uri: lsp::Url::from_file_path("/a/a.rs").unwrap(), @@ -2171,18 +2196,14 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Set up a fake language server. - let (language_server_config, mut fake_language_server) = - LanguageServerConfig::fake_with_capabilities( - lsp::ServerCapabilities { - completion_provider: Some(lsp::CompletionOptions { - trigger_characters: Some(vec![".".to_string()]), - ..Default::default() - }), - ..Default::default() - }, - &cx_a, - ) - .await; + let (mut language_server_config, mut fake_language_servers) = LanguageServerConfig::fake(); + language_server_config.set_fake_capabilities(lsp::ServerCapabilities { + completion_provider: Some(lsp::CompletionOptions { + trigger_characters: Some(vec![".".to_string()]), + ..Default::default() + }), + ..Default::default() + }); Arc::get_mut(&mut lang_registry) .unwrap() .add(Arc::new(Language::new( @@ -2196,7 +2217,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2264,6 +2285,11 @@ mod tests { ) }); + let mut fake_language_server = fake_language_servers.next().await.unwrap(); + buffer_b + .condition(&cx_b, |buffer, _| !buffer.completion_triggers().is_empty()) + .await; + // Type a completion trigger character as the guest. editor_b.update(&mut cx_b, |editor, cx| { editor.select_ranges([13..13], None, cx); @@ -2273,45 +2299,51 @@ mod tests { // Receive a completion request as the host's language server. // Return some completions from the host's language server. - fake_language_server.handle_request::(|params| { - assert_eq!( - params.text_document_position.text_document.uri, - lsp::Url::from_file_path("/a/main.rs").unwrap(), - ); - assert_eq!( - params.text_document_position.position, - lsp::Position::new(0, 14), - ); + cx_a.foreground().start_waiting(); + fake_language_server + .handle_request::(|params| { + assert_eq!( + params.text_document_position.text_document.uri, + lsp::Url::from_file_path("/a/main.rs").unwrap(), + ); + assert_eq!( + params.text_document_position.position, + lsp::Position::new(0, 14), + ); - Some(lsp::CompletionResponse::Array(vec![ - lsp::CompletionItem { - label: "first_method(…)".into(), - detail: Some("fn(&mut self, B) -> C".into()), - text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { - new_text: "first_method($1)".to_string(), - range: lsp::Range::new( - lsp::Position::new(0, 14), - lsp::Position::new(0, 14), - ), - })), - insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), - ..Default::default() - }, - lsp::CompletionItem { - label: "second_method(…)".into(), - detail: Some("fn(&mut self, C) -> D".into()), - text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { - new_text: "second_method()".to_string(), - range: lsp::Range::new( - lsp::Position::new(0, 14), - lsp::Position::new(0, 14), - ), - })), - insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), - ..Default::default() - }, - ])) - }); + Some(lsp::CompletionResponse::Array(vec![ + lsp::CompletionItem { + label: "first_method(…)".into(), + detail: Some("fn(&mut self, B) -> C".into()), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + new_text: "first_method($1)".to_string(), + range: lsp::Range::new( + lsp::Position::new(0, 14), + lsp::Position::new(0, 14), + ), + })), + insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), + ..Default::default() + }, + lsp::CompletionItem { + label: "second_method(…)".into(), + detail: Some("fn(&mut self, C) -> D".into()), + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + new_text: "second_method()".to_string(), + range: lsp::Range::new( + lsp::Position::new(0, 14), + lsp::Position::new(0, 14), + ), + })), + insert_text_format: Some(lsp::InsertTextFormat::SNIPPET), + ..Default::default() + }, + ])) + }) + .next() + .await + .unwrap(); + cx_a.foreground().finish_waiting(); // Open the buffer on the host. let buffer_a = project_a @@ -2325,9 +2357,10 @@ mod tests { .await; // Confirm a completion on the guest. - editor_b.next_notification(&cx_b).await; + editor_b + .condition(&cx_b, |editor, _| editor.context_menu_visible()) + .await; editor_b.update(&mut cx_b, |editor, cx| { - assert!(editor.context_menu_visible()); editor.confirm_completion(&ConfirmCompletion(Some(0)), cx); assert_eq!(editor.text(cx), "fn main() { a.first_method() }"); }); @@ -2352,22 +2385,17 @@ mod tests { } }); + // The additional edit is applied. buffer_a .condition(&cx_a, |buffer, _| { - buffer.text() == "fn main() { a.first_method() }" + buffer.text() == "use d::SomeTrait;\nfn main() { a.first_method() }" }) .await; - - // The additional edit is applied. buffer_b .condition(&cx_b, |buffer, _| { buffer.text() == "use d::SomeTrait;\nfn main() { a.first_method() }" }) .await; - assert_eq!( - buffer_a.read_with(&cx_a, |buffer, _| buffer.text()), - buffer_b.read_with(&cx_b, |buffer, _| buffer.text()), - ); } #[gpui::test(iterations = 10)] @@ -2377,8 +2405,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Set up a fake language server. - let (language_server_config, mut fake_language_server) = - LanguageServerConfig::fake(&cx_a).await; + let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake(); Arc::get_mut(&mut lang_registry) .unwrap() .add(Arc::new(Language::new( @@ -2392,7 +2419,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2452,6 +2479,7 @@ mod tests { project.format(HashSet::from_iter([buffer_b.clone()]), true, cx) }); + let mut fake_language_server = fake_language_servers.next().await.unwrap(); fake_language_server.handle_request::(|_| { Some(vec![ lsp::TextEdit { @@ -2494,8 +2522,7 @@ mod tests { .await; // Set up a fake language server. - let (language_server_config, mut fake_language_server) = - LanguageServerConfig::fake(&cx_a).await; + let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake(); Arc::get_mut(&mut lang_registry) .unwrap() .add(Arc::new(Language::new( @@ -2509,7 +2536,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2560,6 +2587,8 @@ mod tests { // Request the definition of a symbol as the guest. let definitions_1 = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b, 23, cx)); + + let mut fake_language_server = fake_language_servers.next().await.unwrap(); fake_language_server.handle_request::(|_| { Some(lsp::GotoDefinitionResponse::Scalar(lsp::Location::new( lsp::Url::from_file_path("/root-2/b.rs").unwrap(), @@ -2640,8 +2669,8 @@ mod tests { .await; // Set up a fake language server. - let (language_server_config, mut fake_language_server) = - LanguageServerConfig::fake(&cx_a).await; + let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake(); + Arc::get_mut(&mut lang_registry) .unwrap() .add(Arc::new(Language::new( @@ -2655,7 +2684,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2669,6 +2698,7 @@ mod tests { cx, ) }); + let (worktree_a, _) = project_a .update(&mut cx_a, |p, cx| { p.find_or_create_local_worktree("/root", false, cx) @@ -2715,6 +2745,7 @@ mod tests { definitions = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b1, 23, cx)); } + let mut fake_language_server = fake_language_servers.next().await.unwrap(); fake_language_server.handle_request::(|_| { Some(lsp::GotoDefinitionResponse::Scalar(lsp::Location::new( lsp::Url::from_file_path("/root/b.rs").unwrap(), @@ -2740,14 +2771,7 @@ mod tests { cx_b.update(|cx| editor::init(cx, &mut path_openers_b)); // Set up a fake language server. - let (language_server_config, mut fake_language_server) = - LanguageServerConfig::fake_with_capabilities( - lsp::ServerCapabilities { - ..Default::default() - }, - &cx_a, - ) - .await; + let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake(); Arc::get_mut(&mut lang_registry) .unwrap() .add(Arc::new(Language::new( @@ -2761,7 +2785,7 @@ mod tests { ))); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -2827,6 +2851,8 @@ mod tests { .unwrap() .downcast::() .unwrap(); + + let mut fake_language_server = fake_language_servers.next().await.unwrap(); fake_language_server .handle_request::(|params| { assert_eq!( @@ -2845,58 +2871,62 @@ mod tests { editor.select_ranges([Point::new(1, 31)..Point::new(1, 31)], None, cx); cx.focus(&editor_b); }); - fake_language_server.handle_request::(|params| { - assert_eq!( - params.text_document.uri, - lsp::Url::from_file_path("/a/main.rs").unwrap(), - ); - assert_eq!(params.range.start, lsp::Position::new(1, 31)); - assert_eq!(params.range.end, lsp::Position::new(1, 31)); - - Some(vec![lsp::CodeActionOrCommand::CodeAction( - lsp::CodeAction { - title: "Inline into all callers".to_string(), - edit: Some(lsp::WorkspaceEdit { - changes: Some( - [ - ( - lsp::Url::from_file_path("/a/main.rs").unwrap(), - vec![lsp::TextEdit::new( - lsp::Range::new( - lsp::Position::new(1, 22), - lsp::Position::new(1, 34), - ), - "4".to_string(), - )], - ), - ( - lsp::Url::from_file_path("/a/other.rs").unwrap(), - vec![lsp::TextEdit::new( - lsp::Range::new( - lsp::Position::new(0, 0), - lsp::Position::new(0, 27), - ), - "".to_string(), - )], - ), - ] - .into_iter() - .collect(), - ), - ..Default::default() - }), - data: Some(json!({ - "codeActionParams": { - "range": { - "start": {"line": 1, "column": 31}, - "end": {"line": 1, "column": 31}, + + fake_language_server + .handle_request::(|params| { + assert_eq!( + params.text_document.uri, + lsp::Url::from_file_path("/a/main.rs").unwrap(), + ); + assert_eq!(params.range.start, lsp::Position::new(1, 31)); + assert_eq!(params.range.end, lsp::Position::new(1, 31)); + + Some(vec![lsp::CodeActionOrCommand::CodeAction( + lsp::CodeAction { + title: "Inline into all callers".to_string(), + edit: Some(lsp::WorkspaceEdit { + changes: Some( + [ + ( + lsp::Url::from_file_path("/a/main.rs").unwrap(), + vec![lsp::TextEdit::new( + lsp::Range::new( + lsp::Position::new(1, 22), + lsp::Position::new(1, 34), + ), + "4".to_string(), + )], + ), + ( + lsp::Url::from_file_path("/a/other.rs").unwrap(), + vec![lsp::TextEdit::new( + lsp::Range::new( + lsp::Position::new(0, 0), + lsp::Position::new(0, 27), + ), + "".to_string(), + )], + ), + ] + .into_iter() + .collect(), + ), + ..Default::default() + }), + data: Some(json!({ + "codeActionParams": { + "range": { + "start": {"line": 1, "column": 31}, + "end": {"line": 1, "column": 31}, + } } - } - })), - ..Default::default() - }, - )]) - }); + })), + ..Default::default() + }, + )]) + }) + .next() + .await; // Toggle code actions and wait for them to display. editor_b.update(&mut cx_b, |editor, cx| { @@ -2906,6 +2936,8 @@ mod tests { .condition(&cx_b, |editor, _| editor.context_menu_visible()) .await; + fake_language_server.remove_request_handler::(); + // Confirming the code action will trigger a resolve request. let confirm_action = workspace_b .update(&mut cx_b, |workspace, cx| { @@ -2974,7 +3006,7 @@ mod tests { cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; @@ -3113,7 +3145,7 @@ mod tests { async fn test_chat_message_validation(mut cx_a: TestAppContext) { cx_a.foreground().forbid_parking(); - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let db = &server.app_state.db; @@ -3174,7 +3206,7 @@ mod tests { cx_a.foreground().forbid_parking(); // Connect to a server as 2 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; let mut status_b = client_b.status(); @@ -3392,7 +3424,7 @@ mod tests { let fs = Arc::new(FakeFs::new(cx_a.background())); // Connect to a server as 3 clients. - let mut server = TestServer::start(cx_a.foreground()).await; + let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await; let client_a = server.create_client(&mut cx_a, "user_a").await; let client_b = server.create_client(&mut cx_b, "user_b").await; let client_c = server.create_client(&mut cx_c, "user_c").await; @@ -3523,23 +3555,269 @@ mod tests { } } + #[gpui::test(iterations = 100)] + async fn test_random_collaboration(cx: TestAppContext, rng: StdRng) { + cx.foreground().forbid_parking(); + let max_peers = env::var("MAX_PEERS") + .map(|i| i.parse().expect("invalid `MAX_PEERS` variable")) + .unwrap_or(5); + let max_operations = env::var("OPERATIONS") + .map(|i| i.parse().expect("invalid `OPERATIONS` variable")) + .unwrap_or(10); + + let rng = Rc::new(RefCell::new(rng)); + + let mut host_lang_registry = Arc::new(LanguageRegistry::new()); + let guest_lang_registry = Arc::new(LanguageRegistry::new()); + + // Set up a fake language server. + let (mut language_server_config, _fake_language_servers) = LanguageServerConfig::fake(); + language_server_config.set_fake_initializer(|fake_server| { + fake_server.handle_request::(|_| { + Some(lsp::CompletionResponse::Array(vec![lsp::CompletionItem { + text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)), + new_text: "the-new-text".to_string(), + })), + ..Default::default() + }])) + }); + + fake_server.handle_request::(|_| { + Some(vec![lsp::CodeActionOrCommand::CodeAction( + lsp::CodeAction { + title: "the-code-action".to_string(), + ..Default::default() + }, + )]) + }); + }); + + Arc::get_mut(&mut host_lang_registry) + .unwrap() + .add(Arc::new(Language::new( + LanguageConfig { + name: "Rust".to_string(), + path_suffixes: vec!["rs".to_string()], + language_server: Some(language_server_config), + ..Default::default() + }, + None, + ))); + + let fs = Arc::new(FakeFs::new(cx.background())); + fs.insert_tree( + "/_collab", + json!({ + ".zed.toml": r#"collaborators = ["guest-1", "guest-2", "guest-3", "guest-4", "guest-5"]"# + }), + ) + .await; + + let operations = Rc::new(Cell::new(0)); + let mut server = TestServer::start(cx.foreground(), cx.background()).await; + let mut clients = Vec::new(); + + let mut next_entity_id = 100000; + let mut host_cx = TestAppContext::new( + cx.foreground_platform(), + cx.platform(), + cx.foreground(), + cx.background(), + cx.font_cache(), + next_entity_id, + ); + let host = server.create_client(&mut host_cx, "host").await; + let host_project = host_cx.update(|cx| { + Project::local( + host.client.clone(), + host.user_store.clone(), + host_lang_registry.clone(), + fs.clone(), + cx, + ) + }); + let host_project_id = host_project + .update(&mut host_cx, |p, _| p.next_remote_id()) + .await; + + let (collab_worktree, _) = host_project + .update(&mut host_cx, |project, cx| { + project.find_or_create_local_worktree("/_collab", false, cx) + }) + .await + .unwrap(); + collab_worktree + .read_with(&host_cx, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + host_project + .update(&mut host_cx, |project, cx| project.share(cx)) + .await + .unwrap(); + + clients.push(cx.foreground().spawn(host.simulate_host( + host_project.clone(), + operations.clone(), + max_operations, + rng.clone(), + host_cx.clone(), + ))); + + while operations.get() < max_operations { + cx.background().simulate_random_delay().await; + if clients.len() < max_peers && rng.borrow_mut().gen_bool(0.05) { + operations.set(operations.get() + 1); + + let guest_id = clients.len(); + log::info!("Adding guest {}", guest_id); + next_entity_id += 100000; + let mut guest_cx = TestAppContext::new( + cx.foreground_platform(), + cx.platform(), + cx.foreground(), + cx.background(), + cx.font_cache(), + next_entity_id, + ); + let guest = server + .create_client(&mut guest_cx, &format!("guest-{}", guest_id)) + .await; + let guest_project = Project::remote( + host_project_id, + guest.client.clone(), + guest.user_store.clone(), + guest_lang_registry.clone(), + fs.clone(), + &mut guest_cx.to_async(), + ) + .await + .unwrap(); + clients.push(cx.foreground().spawn(guest.simulate_guest( + guest_id, + guest_project, + operations.clone(), + max_operations, + rng.clone(), + guest_cx, + ))); + + log::info!("Guest {} added", guest_id); + } + } + + let clients = futures::future::join_all(clients).await; + cx.foreground().run_until_parked(); + + let host_worktree_snapshots = host_project.read_with(&host_cx, |project, cx| { + project + .worktrees(cx) + .map(|worktree| { + let snapshot = worktree.read(cx).snapshot(); + (snapshot.id(), snapshot) + }) + .collect::>() + }); + + for (guest_client, guest_cx) in clients.iter().skip(1) { + let guest_id = guest_client.client.id(); + let worktree_snapshots = + guest_client + .project + .as_ref() + .unwrap() + .read_with(guest_cx, |project, cx| { + project + .worktrees(cx) + .map(|worktree| { + let worktree = worktree.read(cx); + assert!( + !worktree.as_remote().unwrap().has_pending_updates(), + "Guest {} worktree {:?} contains deferred updates", + guest_id, + worktree.id() + ); + (worktree.id(), worktree.snapshot()) + }) + .collect::>() + }); + + assert_eq!( + worktree_snapshots.keys().collect::>(), + host_worktree_snapshots.keys().collect::>(), + "guest {} has different worktrees than the host", + guest_id + ); + for (id, host_snapshot) in &host_worktree_snapshots { + let guest_snapshot = &worktree_snapshots[id]; + assert_eq!( + guest_snapshot.root_name(), + host_snapshot.root_name(), + "guest {} has different root name than the host for worktree {}", + guest_id, + id + ); + assert_eq!( + guest_snapshot.entries(false).collect::>(), + host_snapshot.entries(false).collect::>(), + "guest {} has different snapshot than the host for worktree {}", + guest_id, + id + ); + } + + guest_client + .project + .as_ref() + .unwrap() + .read_with(guest_cx, |project, _| { + assert!( + !project.has_buffered_operations(), + "guest {} has buffered operations ", + guest_id, + ); + }); + + for guest_buffer in &guest_client.buffers { + let buffer_id = guest_buffer.read_with(guest_cx, |buffer, _| buffer.remote_id()); + let host_buffer = host_project.read_with(&host_cx, |project, _| { + project + .shared_buffer(guest_client.peer_id, buffer_id) + .expect(&format!( + "host doest not have buffer for guest:{}, peer:{}, id:{}", + guest_id, guest_client.peer_id, buffer_id + )) + }); + assert_eq!( + guest_buffer.read_with(guest_cx, |buffer, _| buffer.text()), + host_buffer.read_with(&host_cx, |buffer, _| buffer.text()), + "guest {} buffer {} differs from the host's buffer", + guest_id, + buffer_id, + ); + } + } + } + struct TestServer { peer: Arc, app_state: Arc, server: Arc, foreground: Rc, - notifications: mpsc::Receiver<()>, + notifications: mpsc::UnboundedReceiver<()>, connection_killers: Arc>>>>, forbid_connections: Arc, _test_db: TestDb, } impl TestServer { - async fn start(foreground: Rc) -> Self { - let test_db = TestDb::new(); + async fn start( + foreground: Rc, + background: Arc, + ) -> Self { + let test_db = TestDb::fake(background); let app_state = Self::build_app_state(&test_db).await; let peer = Peer::new(); - let notifications = mpsc::channel(128); + let notifications = mpsc::unbounded(); let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0)); Self { peer, @@ -3561,7 +3839,7 @@ mod tests { let server = self.server.clone(); let connection_killers = self.connection_killers.clone(); let forbid_connections = self.forbid_connections.clone(); - let (connection_id_tx, mut connection_id_rx) = postage::mpsc::channel(16); + let (connection_id_tx, mut connection_id_rx) = mpsc::channel(16); Arc::get_mut(&mut client) .unwrap() @@ -3598,6 +3876,7 @@ mod tests { client_name, user_id, Some(connection_id_tx), + cx.background(), )) .detach(); Ok(client_conn) @@ -3610,6 +3889,9 @@ mod tests { .await .unwrap(); + Channel::init(&client); + Project::init(&client); + let peer_id = PeerId(connection_id_rx.next().await.unwrap().0); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); let mut authed_user = @@ -3620,6 +3902,8 @@ mod tests { client, peer_id, user_store, + project: Default::default(), + buffers: Default::default(), } } @@ -3682,6 +3966,8 @@ mod tests { client: Arc, pub peer_id: PeerId, pub user_store: ModelHandle, + project: Option>, + buffers: HashSet>, } impl Deref for TestClient { @@ -3699,6 +3985,267 @@ mod tests { .read_with(cx, |user_store, _| user_store.current_user().unwrap().id), ) } + + async fn simulate_host( + mut self, + project: ModelHandle, + operations: Rc>, + max_operations: usize, + rng: Rc>, + mut cx: TestAppContext, + ) -> (Self, TestAppContext) { + let fs = project.read_with(&cx, |project, _| project.fs().clone()); + let mut files: Vec = Default::default(); + while operations.get() < max_operations { + operations.set(operations.get() + 1); + + let distribution = rng.borrow_mut().gen_range(0..100); + match distribution { + 0..=20 if !files.is_empty() => { + let mut path = files.choose(&mut *rng.borrow_mut()).unwrap().as_path(); + while let Some(parent_path) = path.parent() { + path = parent_path; + if rng.borrow_mut().gen() { + break; + } + } + + log::info!("Host: find/create local worktree {:?}", path); + project + .update(&mut cx, |project, cx| { + project.find_or_create_local_worktree(path, false, cx) + }) + .await + .unwrap(); + } + 10..=80 if !files.is_empty() => { + let buffer = if self.buffers.is_empty() || rng.borrow_mut().gen() { + let file = files.choose(&mut *rng.borrow_mut()).unwrap(); + let (worktree, path) = project + .update(&mut cx, |project, cx| { + project.find_or_create_local_worktree(file, false, cx) + }) + .await + .unwrap(); + let project_path = + worktree.read_with(&cx, |worktree, _| (worktree.id(), path)); + log::info!("Host: opening path {:?}", project_path); + let buffer = project + .update(&mut cx, |project, cx| { + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + self.buffers.insert(buffer.clone()); + buffer + } else { + self.buffers + .iter() + .choose(&mut *rng.borrow_mut()) + .unwrap() + .clone() + }; + + if rng.borrow_mut().gen_bool(0.1) { + cx.update(|cx| { + log::info!( + "Host: dropping buffer {:?}", + buffer.read(cx).file().unwrap().full_path(cx) + ); + self.buffers.remove(&buffer); + drop(buffer); + }); + } else { + buffer.update(&mut cx, |buffer, cx| { + log::info!( + "Host: updating buffer {:?}", + buffer.file().unwrap().full_path(cx) + ); + buffer.randomly_edit(&mut *rng.borrow_mut(), 5, cx) + }); + } + } + _ => loop { + let path_component_count = rng.borrow_mut().gen_range(1..=5); + let mut path = PathBuf::new(); + path.push("/"); + for _ in 0..path_component_count { + let letter = rng.borrow_mut().gen_range(b'a'..=b'z'); + path.push(std::str::from_utf8(&[letter]).unwrap()); + } + path.set_extension("rs"); + let parent_path = path.parent().unwrap(); + + log::info!("Host: creating file {:?}", path); + if fs.create_dir(&parent_path).await.is_ok() + && fs.create_file(&path, Default::default()).await.is_ok() + { + files.push(path); + break; + } else { + log::info!("Host: cannot create file"); + } + }, + } + + cx.background().simulate_random_delay().await; + } + + self.project = Some(project); + (self, cx) + } + + pub async fn simulate_guest( + mut self, + guest_id: usize, + project: ModelHandle, + operations: Rc>, + max_operations: usize, + rng: Rc>, + mut cx: TestAppContext, + ) -> (Self, TestAppContext) { + while operations.get() < max_operations { + let buffer = if self.buffers.is_empty() || rng.borrow_mut().gen() { + let worktree = if let Some(worktree) = project.read_with(&cx, |project, cx| { + project + .worktrees(&cx) + .filter(|worktree| { + worktree.read(cx).entries(false).any(|e| e.is_file()) + }) + .choose(&mut *rng.borrow_mut()) + }) { + worktree + } else { + cx.background().simulate_random_delay().await; + continue; + }; + + operations.set(operations.get() + 1); + let project_path = worktree.read_with(&cx, |worktree, _| { + let entry = worktree + .entries(false) + .filter(|e| e.is_file()) + .choose(&mut *rng.borrow_mut()) + .unwrap(); + (worktree.id(), entry.path.clone()) + }); + log::info!("Guest {}: opening path {:?}", guest_id, project_path); + let buffer = project + .update(&mut cx, |project, cx| project.open_buffer(project_path, cx)) + .await + .unwrap(); + self.buffers.insert(buffer.clone()); + buffer + } else { + operations.set(operations.get() + 1); + + self.buffers + .iter() + .choose(&mut *rng.borrow_mut()) + .unwrap() + .clone() + }; + + let choice = rng.borrow_mut().gen_range(0..100); + match choice { + 0..=9 => { + cx.update(|cx| { + log::info!( + "Guest {}: dropping buffer {:?}", + guest_id, + buffer.read(cx).file().unwrap().full_path(cx) + ); + self.buffers.remove(&buffer); + drop(buffer); + }); + } + 10..=19 => { + let completions = project.update(&mut cx, |project, cx| { + log::info!( + "Guest {}: requesting completions for buffer {:?}", + guest_id, + buffer.read(cx).file().unwrap().full_path(cx) + ); + let offset = rng.borrow_mut().gen_range(0..=buffer.read(cx).len()); + project.completions(&buffer, offset, cx) + }); + let completions = cx.background().spawn(async move { + completions.await.expect("completions request failed"); + }); + if rng.borrow_mut().gen_bool(0.3) { + log::info!("Guest {}: detaching completions request", guest_id); + completions.detach(); + } else { + completions.await; + } + } + 20..=29 => { + let code_actions = project.update(&mut cx, |project, cx| { + log::info!( + "Guest {}: requesting code actions for buffer {:?}", + guest_id, + buffer.read(cx).file().unwrap().full_path(cx) + ); + let range = + buffer.read(cx).random_byte_range(0, &mut *rng.borrow_mut()); + project.code_actions(&buffer, range, cx) + }); + let code_actions = cx.background().spawn(async move { + code_actions.await.expect("code actions request failed"); + }); + if rng.borrow_mut().gen_bool(0.3) { + log::info!("Guest {}: detaching code actions request", guest_id); + code_actions.detach(); + } else { + code_actions.await; + } + } + 30..=39 if buffer.read_with(&cx, |buffer, _| buffer.is_dirty()) => { + let (requested_version, save) = buffer.update(&mut cx, |buffer, cx| { + log::info!( + "Guest {}: saving buffer {:?}", + guest_id, + buffer.file().unwrap().full_path(cx) + ); + (buffer.version(), buffer.save(cx)) + }); + let save = cx.spawn(|cx| async move { + let (saved_version, _) = save.await.expect("save request failed"); + buffer.read_with(&cx, |buffer, _| { + assert!(buffer.version().observed_all(&saved_version)); + assert!(saved_version.observed_all(&requested_version)); + }); + }); + if rng.borrow_mut().gen_bool(0.3) { + log::info!("Guest {}: detaching save request", guest_id); + save.detach(); + } else { + save.await; + } + } + _ => { + buffer.update(&mut cx, |buffer, cx| { + log::info!( + "Guest {}: updating buffer {:?}", + guest_id, + buffer.file().unwrap().full_path(cx) + ); + buffer.randomly_edit(&mut *rng.borrow_mut(), 5, cx) + }); + } + } + cx.background().simulate_random_delay().await; + } + + self.project = Some(project); + (self, cx) + } + } + + impl Executor for Arc { + fn spawn_detached>(&self, future: F) { + self.spawn(future).detach(); + } } fn channel_messages(channel: &Channel) -> Vec<(String, String, bool)> { diff --git a/crates/server/src/rpc/store.rs b/crates/server/src/rpc/store.rs index 5cb0a0e1db028631c468b1ab3519114483a12f14..41c611a0972229261c7a80884ccceb309ee97ff4 100644 --- a/crates/server/src/rpc/store.rs +++ b/crates/server/src/rpc/store.rs @@ -43,6 +43,7 @@ pub struct ProjectShare { pub struct WorktreeShare { pub entries: HashMap, pub diagnostic_summaries: BTreeMap, + pub next_update_id: u64, } #[derive(Default)] @@ -403,6 +404,7 @@ impl Store { connection_id: ConnectionId, entries: HashMap, diagnostic_summaries: BTreeMap, + next_update_id: u64, ) -> tide::Result { let project = self .projects @@ -416,6 +418,7 @@ impl Store { worktree.share = Some(WorktreeShare { entries, diagnostic_summaries, + next_update_id, }); Ok(SharedWorktree { authorized_user_ids: project.authorized_user_ids(), @@ -534,6 +537,7 @@ impl Store { connection_id: ConnectionId, project_id: u64, worktree_id: u64, + update_id: u64, removed_entries: &[u64], updated_entries: &[proto::Entry], ) -> tide::Result> { @@ -545,6 +549,11 @@ impl Store { .share .as_mut() .ok_or_else(|| anyhow!("worktree is not shared"))?; + if share.next_update_id != update_id { + return Err(anyhow!("received worktree updates out-of-order"))?; + } + + share.next_update_id = update_id + 1; for entry_id in removed_entries { share.entries.remove(&entry_id); } diff --git a/crates/sum_tree/src/sum_tree.rs b/crates/sum_tree/src/sum_tree.rs index 67c056d858d74b37d29a4490e77348b0781628f8..ea21672b10745445c6de403523c6fb3e4fdcb993 100644 --- a/crates/sum_tree/src/sum_tree.rs +++ b/crates/sum_tree/src/sum_tree.rs @@ -478,6 +478,14 @@ impl SumTree { } } +impl PartialEq for SumTree { + fn eq(&self, other: &Self) -> bool { + self.iter().eq(other.iter()) + } +} + +impl Eq for SumTree {} + impl SumTree { pub fn insert_or_replace(&mut self, item: T, cx: &::Context) -> bool { let mut replaced = false; diff --git a/crates/text/src/text.rs b/crates/text/src/text.rs index 9b7f8dd230e0210b372c0b2ad9f0faf1aef1f004..da003b5d443616b76498a5eb5b51990c66830cf8 100644 --- a/crates/text/src/text.rs +++ b/crates/text/src/text.rs @@ -21,7 +21,7 @@ use operation_queue::OperationQueue; pub use patch::Patch; pub use point::*; pub use point_utf16::*; -use postage::{oneshot, prelude::*}; +use postage::{barrier, oneshot, prelude::*}; #[cfg(any(test, feature = "test-support"))] pub use random_char_iter::*; use rope::TextDimension; @@ -53,6 +53,7 @@ pub struct Buffer { pub lamport_clock: clock::Lamport, subscriptions: Topic, edit_id_resolvers: HashMap>>, + version_barriers: Vec<(clock::Global, barrier::Sender)>, } #[derive(Clone, Debug)] @@ -574,6 +575,7 @@ impl Buffer { lamport_clock, subscriptions: Default::default(), edit_id_resolvers: Default::default(), + version_barriers: Default::default(), } } @@ -835,6 +837,8 @@ impl Buffer { } } } + self.version_barriers + .retain(|(version, _)| !self.snapshot.version().observed_all(version)); Ok(()) } @@ -1305,6 +1309,16 @@ impl Buffer { } } + pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future { + let (tx, mut rx) = barrier::channel(); + if !self.snapshot.version.observed_all(&version) { + self.version_barriers.push((version, tx)); + } + async move { + rx.recv().await; + } + } + fn resolve_edit(&mut self, edit_id: clock::Local) { for mut tx in self .edit_id_resolvers diff --git a/crates/util/src/lib.rs b/crates/util/src/lib.rs index b0c66b005bd0e0deb39962ac37eb5292d63d21bf..919fecf8f9c0097cfceafb4b3a6bfe98e3f91afa 100644 --- a/crates/util/src/lib.rs +++ b/crates/util/src/lib.rs @@ -4,13 +4,14 @@ pub mod test; use futures::Future; use std::{ cmp::Ordering, + ops::AddAssign, pin::Pin, task::{Context, Poll}, }; -pub fn post_inc(value: &mut usize) -> usize { +pub fn post_inc + AddAssign + Copy>(value: &mut T) -> T { let prev = *value; - *value += 1; + *value += T::from(1); prev } diff --git a/crates/workspace/src/workspace.rs b/crates/workspace/src/workspace.rs index 2203e8cbf7d879b2ab3a4ea530a7687586ee6aec..d5d0af6104fe8c4172ce049aca2f7291f9147301 100644 --- a/crates/workspace/src/workspace.rs +++ b/crates/workspace/src/workspace.rs @@ -17,9 +17,9 @@ use gpui::{ json::{self, to_string_pretty, ToJson}, keymap::Binding, platform::{CursorStyle, WindowOptions}, - AnyModelHandle, AnyViewHandle, AppContext, ClipboardItem, Entity, ModelContext, ModelHandle, - MutableAppContext, PathPromptOptions, PromptLevel, RenderContext, Task, View, ViewContext, - ViewHandle, WeakModelHandle, WeakViewHandle, + AnyModelHandle, AnyViewHandle, AppContext, ClipboardItem, Entity, ImageData, ModelContext, + ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, RenderContext, Task, View, + ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle, }; use language::LanguageRegistry; use log::error; @@ -1139,7 +1139,7 @@ impl Workspace { Flex::row() .with_children(self.render_share_icon(theme, cx)) .with_children(self.render_collaborators(theme, cx)) - .with_child(self.render_avatar( + .with_child(self.render_current_user( self.user_store.read(cx).current_user().as_ref(), self.project.read(cx).replica_id(), theme, @@ -1171,13 +1171,17 @@ impl Workspace { collaborators.sort_unstable_by_key(|collaborator| collaborator.replica_id); collaborators .into_iter() - .map(|collaborator| { - self.render_avatar(Some(&collaborator.user), collaborator.replica_id, theme, cx) + .filter_map(|collaborator| { + Some(self.render_avatar( + collaborator.user.avatar.clone()?, + collaborator.replica_id, + theme, + )) }) .collect() } - fn render_avatar( + fn render_current_user( &self, user: Option<&Arc>, replica_id: ReplicaId, @@ -1185,33 +1189,9 @@ impl Workspace { cx: &mut RenderContext, ) -> ElementBox { if let Some(avatar) = user.and_then(|user| user.avatar.clone()) { - ConstrainedBox::new( - Stack::new() - .with_child( - ConstrainedBox::new( - Image::new(avatar) - .with_style(theme.workspace.titlebar.avatar) - .boxed(), - ) - .with_width(theme.workspace.titlebar.avatar_width) - .aligned() - .boxed(), - ) - .with_child( - AvatarRibbon::new(theme.editor.replica_selection_style(replica_id).cursor) - .constrained() - .with_width(theme.workspace.titlebar.avatar_ribbon.width) - .with_height(theme.workspace.titlebar.avatar_ribbon.height) - .aligned() - .bottom() - .boxed(), - ) - .boxed(), - ) - .with_width(theme.workspace.right_sidebar.width) - .boxed() + self.render_avatar(avatar, replica_id, theme) } else { - MouseEventHandler::new::(0, cx, |state, _| { + MouseEventHandler::new::(cx.view_id(), cx, |state, _| { let style = if state.hovered { &theme.workspace.titlebar.hovered_sign_in_prompt } else { @@ -1229,6 +1209,39 @@ impl Workspace { } } + fn render_avatar( + &self, + avatar: Arc, + replica_id: ReplicaId, + theme: &Theme, + ) -> ElementBox { + ConstrainedBox::new( + Stack::new() + .with_child( + ConstrainedBox::new( + Image::new(avatar) + .with_style(theme.workspace.titlebar.avatar) + .boxed(), + ) + .with_width(theme.workspace.titlebar.avatar_width) + .aligned() + .boxed(), + ) + .with_child( + AvatarRibbon::new(theme.editor.replica_selection_style(replica_id).cursor) + .constrained() + .with_width(theme.workspace.titlebar.avatar_ribbon.width) + .with_height(theme.workspace.titlebar.avatar_ribbon.height) + .aligned() + .bottom() + .boxed(), + ) + .boxed(), + ) + .with_width(theme.workspace.right_sidebar.width) + .boxed() + } + fn render_share_icon(&self, theme: &Theme, cx: &mut RenderContext) -> Option { if self.project().read(cx).is_local() && self.client.user_id().is_some() { enum Share {} diff --git a/crates/zed/src/main.rs b/crates/zed/src/main.rs index dd658255550ee6295733d9ff47899a2fd6cb2915..19f11ebc1ecd4080f9c44e270c61836e33642d17 100644 --- a/crates/zed/src/main.rs +++ b/crates/zed/src/main.rs @@ -53,6 +53,8 @@ fn main() { let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http.clone(), cx)); let mut path_openers = Vec::new(); + project::Project::init(&client); + client::Channel::init(&client); client::init(client.clone(), cx); workspace::init(cx); editor::init(cx, &mut path_openers);