Detailed changes
@@ -325,17 +325,21 @@ impl ChatPanel {
enum SignInPromptLabel {}
Align::new(
- MouseEventHandler::new::<SignInPromptLabel, _, _, _>(0, cx, |mouse_state, _| {
- Label::new(
- "Sign in to use chat".to_string(),
- if mouse_state.hovered {
- theme.chat_panel.hovered_sign_in_prompt.clone()
- } else {
- theme.chat_panel.sign_in_prompt.clone()
- },
- )
- .boxed()
- })
+ MouseEventHandler::new::<SignInPromptLabel, _, _, _>(
+ cx.view_id(),
+ cx,
+ |mouse_state, _| {
+ Label::new(
+ "Sign in to use chat".to_string(),
+ if mouse_state.hovered {
+ theme.chat_panel.hovered_sign_in_prompt.clone()
+ } else {
+ theme.chat_panel.sign_in_prompt.clone()
+ },
+ )
+ .boxed()
+ },
+ )
.with_cursor_style(CursorStyle::PointingHand)
.on_click(move |cx| {
let rpc = rpc.clone();
@@ -4,6 +4,7 @@ use super::{
Client, Status, Subscription, TypedEnvelope,
};
use anyhow::{anyhow, Context, Result};
+use futures::lock::Mutex;
use gpui::{
AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task, WeakModelHandle,
};
@@ -40,6 +41,7 @@ pub struct Channel {
next_pending_message_id: usize,
user_store: ModelHandle<UserStore>,
rpc: Arc<Client>,
+ outgoing_messages_lock: Arc<Mutex<()>>,
rng: StdRng,
_subscription: Subscription,
}
@@ -178,14 +180,17 @@ impl Entity for Channel {
}
impl Channel {
+ pub fn init(rpc: &Arc<Client>) {
+ rpc.add_entity_message_handler(Self::handle_message_sent);
+ }
+
pub fn new(
details: ChannelDetails,
user_store: ModelHandle<UserStore>,
rpc: Arc<Client>,
cx: &mut ModelContext<Self>,
) -> Self {
- let _subscription =
- rpc.add_entity_message_handler(details.id, cx, Self::handle_message_sent);
+ let _subscription = rpc.add_model_for_remote_entity(details.id, cx);
{
let user_store = user_store.clone();
@@ -214,6 +219,7 @@ impl Channel {
details,
user_store,
rpc,
+ outgoing_messages_lock: Default::default(),
messages: Default::default(),
loaded_all_messages: false,
next_pending_message_id: 0,
@@ -259,13 +265,16 @@ impl Channel {
);
let user_store = self.user_store.clone();
let rpc = self.rpc.clone();
+ let outgoing_messages_lock = self.outgoing_messages_lock.clone();
Ok(cx.spawn(|this, mut cx| async move {
+ let outgoing_message_guard = outgoing_messages_lock.lock().await;
let request = rpc.request(proto::SendChannelMessage {
channel_id,
body,
nonce: Some(nonce.into()),
});
let response = request.await?;
+ drop(outgoing_message_guard);
let message = ChannelMessage::from_proto(
response.message.ok_or_else(|| anyhow!("invalid message"))?,
&user_store,
@@ -589,10 +598,14 @@ mod tests {
#[gpui::test]
async fn test_channel_messages(mut cx: TestAppContext) {
+ cx.foreground().forbid_parking();
+
let user_id = 5;
let http_client = FakeHttpClient::new(|_| async move { Ok(Response::new(404)) });
let mut client = Client::new(http_client.clone());
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
+
+ Channel::init(&client);
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http_client, cx));
let channel_list = cx.add_model(|cx| ChannelList::new(user_store, client.clone(), cx));
@@ -12,7 +12,10 @@ use async_tungstenite::tungstenite::{
http::{Request, StatusCode},
};
use futures::{future::LocalBoxFuture, FutureExt, StreamExt};
-use gpui::{action, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task};
+use gpui::{
+ action, AnyModelHandle, AnyWeakModelHandle, AsyncAppContext, Entity, ModelContext, ModelHandle,
+ MutableAppContext, Task,
+};
use http::HttpClient;
use lazy_static::lazy_static;
use parking_lot::RwLock;
@@ -20,7 +23,7 @@ use postage::watch;
use rand::prelude::*;
use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage};
use std::{
- any::{type_name, TypeId},
+ any::TypeId,
collections::HashMap,
convert::TryFrom,
fmt::Write as _,
@@ -124,19 +127,28 @@ pub enum Status {
ReconnectionError { next_reconnection: Instant },
}
-type ModelHandler = Box<
- dyn Send
- + Sync
- + FnMut(Box<dyn AnyTypedEnvelope>, &AsyncAppContext) -> LocalBoxFuture<'static, Result<()>>,
->;
-
struct ClientState {
credentials: Option<Credentials>,
status: (watch::Sender<Status>, watch::Receiver<Status>),
entity_id_extractors: HashMap<TypeId, Box<dyn Send + Sync + Fn(&dyn AnyTypedEnvelope) -> u64>>,
- model_handlers: HashMap<(TypeId, Option<u64>), Option<ModelHandler>>,
_maintain_connection: Option<Task<()>>,
heartbeat_interval: Duration,
+
+ models_by_entity_type_and_remote_id: HashMap<(TypeId, u64), AnyWeakModelHandle>,
+ models_by_message_type: HashMap<TypeId, AnyModelHandle>,
+ model_types_by_message_type: HashMap<TypeId, TypeId>,
+ message_handlers: HashMap<
+ TypeId,
+ Arc<
+ dyn Send
+ + Sync
+ + Fn(
+ AnyModelHandle,
+ Box<dyn AnyTypedEnvelope>,
+ AsyncAppContext,
+ ) -> LocalBoxFuture<'static, Result<()>>,
+ >,
+ >,
}
#[derive(Clone, Debug)]
@@ -151,23 +163,43 @@ impl Default for ClientState {
credentials: None,
status: watch::channel_with(Status::SignedOut),
entity_id_extractors: Default::default(),
- model_handlers: Default::default(),
_maintain_connection: None,
heartbeat_interval: Duration::from_secs(5),
+ models_by_message_type: Default::default(),
+ models_by_entity_type_and_remote_id: Default::default(),
+ model_types_by_message_type: Default::default(),
+ message_handlers: Default::default(),
}
}
}
-pub struct Subscription {
- client: Weak<Client>,
- id: (TypeId, Option<u64>),
+pub enum Subscription {
+ Entity {
+ client: Weak<Client>,
+ id: (TypeId, u64),
+ },
+ Message {
+ client: Weak<Client>,
+ id: TypeId,
+ },
}
impl Drop for Subscription {
fn drop(&mut self) {
- if let Some(client) = self.client.upgrade() {
- let mut state = client.state.write();
- let _ = state.model_handlers.remove(&self.id).unwrap();
+ match self {
+ Subscription::Entity { client, id } => {
+ if let Some(client) = client.upgrade() {
+ let mut state = client.state.write();
+ let _ = state.models_by_entity_type_and_remote_id.remove(id);
+ }
+ }
+ Subscription::Message { client, id } => {
+ if let Some(client) = client.upgrade() {
+ let mut state = client.state.write();
+ let _ = state.model_types_by_message_type.remove(id);
+ let _ = state.message_handlers.remove(id);
+ }
+ }
}
}
}
@@ -188,6 +220,10 @@ impl Client {
})
}
+ pub fn id(&self) -> usize {
+ self.id
+ }
+
#[cfg(any(test, feature = "test-support"))]
pub fn override_authenticate<F>(&mut self, authenticate: F) -> &mut Self
where
@@ -266,125 +302,118 @@ impl Client {
}
}
- pub fn add_message_handler<T, M, F, Fut>(
+ pub fn add_model_for_remote_entity<T: Entity>(
+ self: &Arc<Self>,
+ remote_id: u64,
+ cx: &mut ModelContext<T>,
+ ) -> Subscription {
+ let handle = AnyModelHandle::from(cx.handle());
+ let mut state = self.state.write();
+ let id = (TypeId::of::<T>(), remote_id);
+ state
+ .models_by_entity_type_and_remote_id
+ .insert(id, handle.downgrade());
+ Subscription::Entity {
+ client: Arc::downgrade(self),
+ id,
+ }
+ }
+
+ pub fn add_message_handler<M, E, H, F>(
self: &Arc<Self>,
- cx: &mut ModelContext<M>,
- mut handler: F,
+ model: ModelHandle<E>,
+ handler: H,
) -> Subscription
where
- T: EnvelopedMessage,
- M: Entity,
- F: 'static
+ M: EnvelopedMessage,
+ E: Entity,
+ H: 'static
+ Send
+ Sync
- + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
- Fut: 'static + Future<Output = Result<()>>,
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<()>>,
{
- let subscription_id = (TypeId::of::<T>(), None);
+ let message_type_id = TypeId::of::<M>();
+
let client = self.clone();
let mut state = self.state.write();
- let model = cx.weak_handle();
- let prev_handler = state.model_handlers.insert(
- subscription_id,
- Some(Box::new(move |envelope, cx| {
- if let Some(model) = model.upgrade(cx) {
- let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
- handler(model, *envelope, client.clone(), cx.clone()).boxed_local()
- } else {
- async move {
- Err(anyhow!(
- "received message for {:?} but model was dropped",
- type_name::<M>()
- ))
- }
- .boxed_local()
- }
- })),
+ state
+ .models_by_message_type
+ .insert(message_type_id, model.into());
+
+ let prev_handler = state.message_handlers.insert(
+ message_type_id,
+ Arc::new(move |handle, envelope, cx| {
+ let model = handle.downcast::<E>().unwrap();
+ let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
+ handler(model, *envelope, client.clone(), cx).boxed_local()
+ }),
);
if prev_handler.is_some() {
panic!("registered handler for the same message twice");
}
- Subscription {
+ Subscription::Message {
client: Arc::downgrade(self),
- id: subscription_id,
+ id: message_type_id,
}
}
- pub fn add_entity_message_handler<T, M, F, Fut>(
- self: &Arc<Self>,
- remote_id: u64,
- cx: &mut ModelContext<M>,
- mut handler: F,
- ) -> Subscription
+ pub fn add_entity_message_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
where
- T: EntityMessage,
- M: Entity,
- F: 'static
+ M: EntityMessage,
+ E: Entity,
+ H: 'static
+ Send
+ Sync
- + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
- Fut: 'static + Future<Output = Result<()>>,
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<()>>,
{
- let subscription_id = (TypeId::of::<T>(), Some(remote_id));
+ let model_type_id = TypeId::of::<E>();
+ let message_type_id = TypeId::of::<M>();
+
let client = self.clone();
let mut state = self.state.write();
- let model = cx.weak_handle();
+ state
+ .model_types_by_message_type
+ .insert(message_type_id, model_type_id);
state
.entity_id_extractors
- .entry(subscription_id.0)
+ .entry(message_type_id)
.or_insert_with(|| {
Box::new(|envelope| {
let envelope = envelope
.as_any()
- .downcast_ref::<TypedEnvelope<T>>()
+ .downcast_ref::<TypedEnvelope<M>>()
.unwrap();
envelope.payload.remote_entity_id()
})
});
- let prev_handler = state.model_handlers.insert(
- subscription_id,
- Some(Box::new(move |envelope, cx| {
- if let Some(model) = model.upgrade(cx) {
- let envelope = envelope.into_any().downcast::<TypedEnvelope<T>>().unwrap();
- handler(model, *envelope, client.clone(), cx.clone()).boxed_local()
- } else {
- async move {
- Err(anyhow!(
- "received message for {:?} but model was dropped",
- type_name::<M>()
- ))
- }
- .boxed_local()
- }
- })),
+
+ let prev_handler = state.message_handlers.insert(
+ message_type_id,
+ Arc::new(move |handle, envelope, cx| {
+ let model = handle.downcast::<E>().unwrap();
+ let envelope = envelope.into_any().downcast::<TypedEnvelope<M>>().unwrap();
+ handler(model, *envelope, client.clone(), cx).boxed_local()
+ }),
);
if prev_handler.is_some() {
- panic!("registered a handler for the same entity twice")
- }
-
- Subscription {
- client: Arc::downgrade(self),
- id: subscription_id,
+ panic!("registered handler for the same message twice");
}
}
- pub fn add_entity_request_handler<T, M, F, Fut>(
- self: &Arc<Self>,
- remote_id: u64,
- cx: &mut ModelContext<M>,
- mut handler: F,
- ) -> Subscription
+ pub fn add_entity_request_handler<M, E, H, F>(self: &Arc<Self>, handler: H)
where
- T: EntityMessage + RequestMessage,
- M: Entity,
- F: 'static
+ M: EntityMessage + RequestMessage,
+ E: Entity,
+ H: 'static
+ Send
+ Sync
- + FnMut(ModelHandle<M>, TypedEnvelope<T>, Arc<Self>, AsyncAppContext) -> Fut,
- Fut: 'static + Future<Output = Result<T::Response>>,
+ + Fn(ModelHandle<E>, TypedEnvelope<M>, Arc<Self>, AsyncAppContext) -> F,
+ F: 'static + Future<Output = Result<M::Response>>,
{
- self.add_entity_message_handler(remote_id, cx, move |model, envelope, client, cx| {
+ self.add_entity_message_handler(move |model, envelope, client, cx| {
let receipt = envelope.receipt();
let response = handler(model, envelope, client.clone(), cx);
async move {
@@ -500,27 +529,45 @@ impl Client {
while let Some(message) = incoming.next().await {
let mut state = this.state.write();
let payload_type_id = message.payload_type_id();
- let entity_id = if let Some(extract_entity_id) =
- state.entity_id_extractors.get(&message.payload_type_id())
- {
- Some((extract_entity_id)(message.as_ref()))
+ let type_name = message.payload_type_name();
+
+ let model = state
+ .models_by_message_type
+ .get(&payload_type_id)
+ .cloned()
+ .or_else(|| {
+ let model_type_id =
+ *state.model_types_by_message_type.get(&payload_type_id)?;
+ let entity_id = state
+ .entity_id_extractors
+ .get(&message.payload_type_id())
+ .map(|extract_entity_id| {
+ (extract_entity_id)(message.as_ref())
+ })?;
+ let model = state
+ .models_by_entity_type_and_remote_id
+ .get(&(model_type_id, entity_id))?;
+ if let Some(model) = model.upgrade(&cx) {
+ Some(model)
+ } else {
+ state
+ .models_by_entity_type_and_remote_id
+ .remove(&(model_type_id, entity_id));
+ None
+ }
+ });
+
+ let model = if let Some(model) = model {
+ model
} else {
- None
+ log::info!("unhandled message {}", type_name);
+ continue;
};
- let type_name = message.payload_type_name();
-
- let handler_key = (payload_type_id, entity_id);
- if let Some(handler) = state.model_handlers.get_mut(&handler_key) {
- let mut handler = handler.take().unwrap();
+ if let Some(handler) = state.message_handlers.get(&payload_type_id).cloned()
+ {
drop(state); // Avoid deadlocks if the handler interacts with rpc::Client
- let future = (handler)(message, &cx);
- {
- let mut state = this.state.write();
- if state.model_handlers.contains_key(&handler_key) {
- state.model_handlers.insert(handler_key, Some(handler));
- }
- }
+ let future = handler(model, message, cx.clone());
let client_id = this.id;
log::debug!(
@@ -540,7 +587,7 @@ impl Client {
}
Err(error) => {
log::error!(
- "error handling rpc message. client_id:{}, name:{}, error:{}",
+ "error handling message. client_id:{}, name:{}, {}",
client_id,
type_name,
error
@@ -923,39 +970,39 @@ mod tests {
let mut client = Client::new(FakeHttpClient::with_404_response());
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
- let model = cx.add_model(|_| Model { subscription: None });
- let (mut done_tx1, mut done_rx1) = postage::oneshot::channel();
- let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
- let _subscription1 = model.update(&mut cx, |_, cx| {
- client.add_entity_message_handler(
- 1,
- cx,
- move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
- postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
- async { Ok(()) }
- },
- )
+ let (done_tx1, mut done_rx1) = smol::channel::unbounded();
+ let (done_tx2, mut done_rx2) = smol::channel::unbounded();
+ client.add_entity_message_handler(
+ move |model: ModelHandle<Model>, _: TypedEnvelope<proto::UnshareProject>, _, cx| {
+ match model.read_with(&cx, |model, _| model.id) {
+ 1 => done_tx1.try_send(()).unwrap(),
+ 2 => done_tx2.try_send(()).unwrap(),
+ _ => unreachable!(),
+ }
+ async { Ok(()) }
+ },
+ );
+ let model1 = cx.add_model(|_| Model {
+ id: 1,
+ subscription: None,
});
- let _subscription2 = model.update(&mut cx, |_, cx| {
- client.add_entity_message_handler(
- 2,
- cx,
- move |_, _: TypedEnvelope<proto::UnshareProject>, _, _| {
- postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
- async { Ok(()) }
- },
- )
+ let model2 = cx.add_model(|_| Model {
+ id: 2,
+ subscription: None,
+ });
+ let model3 = cx.add_model(|_| Model {
+ id: 3,
+ subscription: None,
});
+ let _subscription1 =
+ model1.update(&mut cx, |_, cx| client.add_model_for_remote_entity(1, cx));
+ let _subscription2 =
+ model2.update(&mut cx, |_, cx| client.add_model_for_remote_entity(2, cx));
// Ensure dropping a subscription for the same entity type still allows receiving of
// messages for other entity IDs of the same type.
- let subscription3 = model.update(&mut cx, |_, cx| {
- client.add_entity_message_handler(
- 3,
- cx,
- |_, _: TypedEnvelope<proto::UnshareProject>, _, _| async { Ok(()) },
- )
- });
+ let subscription3 =
+ model3.update(&mut cx, |_, cx| client.add_model_for_remote_entity(3, cx));
drop(subscription3);
server.send(proto::UnshareProject { project_id: 1 });
@@ -972,22 +1019,22 @@ mod tests {
let mut client = Client::new(FakeHttpClient::with_404_response());
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
- let model = cx.add_model(|_| Model { subscription: None });
- let (mut done_tx1, _done_rx1) = postage::oneshot::channel();
- let (mut done_tx2, mut done_rx2) = postage::oneshot::channel();
- let subscription1 = model.update(&mut cx, |_, cx| {
- client.add_message_handler(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
- postage::sink::Sink::try_send(&mut done_tx1, ()).unwrap();
+ let model = cx.add_model(|_| Model::default());
+ let (done_tx1, _done_rx1) = smol::channel::unbounded();
+ let (done_tx2, mut done_rx2) = smol::channel::unbounded();
+ let subscription1 = client.add_message_handler(
+ model.clone(),
+ move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+ done_tx1.try_send(()).unwrap();
async { Ok(()) }
- })
- });
+ },
+ );
drop(subscription1);
- let _subscription2 = model.update(&mut cx, |_, cx| {
- client.add_message_handler(cx, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
- postage::sink::Sink::try_send(&mut done_tx2, ()).unwrap();
+ let _subscription2 =
+ client.add_message_handler(model, move |_, _: TypedEnvelope<proto::Ping>, _, _| {
+ done_tx2.try_send(()).unwrap();
async { Ok(()) }
- })
- });
+ });
server.send(proto::Ping {});
done_rx2.next().await.unwrap();
}
@@ -1000,23 +1047,26 @@ mod tests {
let mut client = Client::new(FakeHttpClient::with_404_response());
let server = FakeServer::for_client(user_id, &mut client, &cx).await;
- let model = cx.add_model(|_| Model { subscription: None });
- let (mut done_tx, mut done_rx) = postage::oneshot::channel();
- model.update(&mut cx, |model, cx| {
- model.subscription = Some(client.add_message_handler(
- cx,
- move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
- model.update(&mut cx, |model, _| model.subscription.take());
- postage::sink::Sink::try_send(&mut done_tx, ()).unwrap();
- async { Ok(()) }
- },
- ));
+ let model = cx.add_model(|_| Model::default());
+ let (done_tx, mut done_rx) = smol::channel::unbounded();
+ let subscription = client.add_message_handler(
+ model.clone(),
+ move |model, _: TypedEnvelope<proto::Ping>, _, mut cx| {
+ model.update(&mut cx, |model, _| model.subscription.take());
+ done_tx.try_send(()).unwrap();
+ async { Ok(()) }
+ },
+ );
+ model.update(&mut cx, |model, _| {
+ model.subscription = Some(subscription);
});
server.send(proto::Ping {});
done_rx.next().await.unwrap();
}
+ #[derive(Default)]
struct Model {
+ id: usize,
subscription: Option<Subscription>,
}
@@ -1,25 +1,28 @@
-use super::Client;
-use super::*;
-use crate::http::{HttpClient, Request, Response, ServerResponse};
+use crate::{
+ http::{HttpClient, Request, Response, ServerResponse},
+ Client, Connection, Credentials, EstablishConnectionError, UserStore,
+};
+use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt};
-use gpui::{ModelHandle, TestAppContext};
+use gpui::{executor, ModelHandle, TestAppContext};
use parking_lot::Mutex;
use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope};
-use std::fmt;
-use std::sync::atomic::Ordering::SeqCst;
-use std::sync::{
- atomic::{AtomicBool, AtomicUsize},
- Arc,
-};
+use std::{fmt, rc::Rc, sync::Arc};
pub struct FakeServer {
peer: Arc<Peer>,
- incoming: Mutex<Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>>,
- connection_id: Mutex<Option<ConnectionId>>,
- forbid_connections: AtomicBool,
- auth_count: AtomicUsize,
- access_token: AtomicUsize,
+ state: Arc<Mutex<FakeServerState>>,
user_id: u64,
+ executor: Rc<executor::Foreground>,
+}
+
+#[derive(Default)]
+struct FakeServerState {
+ incoming: Option<BoxStream<'static, Box<dyn proto::AnyTypedEnvelope>>>,
+ connection_id: Option<ConnectionId>,
+ forbid_connections: bool,
+ auth_count: usize,
+ access_token: usize,
}
impl FakeServer {
@@ -27,24 +30,22 @@ impl FakeServer {
client_user_id: u64,
client: &mut Arc<Client>,
cx: &TestAppContext,
- ) -> Arc<Self> {
- let server = Arc::new(Self {
+ ) -> Self {
+ let server = Self {
peer: Peer::new(),
- incoming: Default::default(),
- connection_id: Default::default(),
- forbid_connections: Default::default(),
- auth_count: Default::default(),
- access_token: Default::default(),
+ state: Default::default(),
user_id: client_user_id,
- });
+ executor: cx.foreground(),
+ };
Arc::get_mut(client)
.unwrap()
.override_authenticate({
- let server = server.clone();
+ let state = server.state.clone();
move |cx| {
- server.auth_count.fetch_add(1, SeqCst);
- let access_token = server.access_token.load(SeqCst).to_string();
+ let mut state = state.lock();
+ state.auth_count += 1;
+ let access_token = state.access_token.to_string();
cx.spawn(move |_| async move {
Ok(Credentials {
user_id: client_user_id,
@@ -54,12 +55,32 @@ impl FakeServer {
}
})
.override_establish_connection({
- let server = server.clone();
+ let peer = server.peer.clone();
+ let state = server.state.clone();
move |credentials, cx| {
+ let peer = peer.clone();
+ let state = state.clone();
let credentials = credentials.clone();
- cx.spawn({
- let server = server.clone();
- move |cx| async move { server.establish_connection(&credentials, &cx).await }
+ cx.spawn(move |cx| async move {
+ assert_eq!(credentials.user_id, client_user_id);
+
+ if state.lock().forbid_connections {
+ Err(EstablishConnectionError::Other(anyhow!(
+ "server is forbidding connections"
+ )))?
+ }
+
+ if credentials.access_token != state.lock().access_token.to_string() {
+ Err(EstablishConnectionError::Unauthorized)?
+ }
+
+ let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
+ let (connection_id, io, incoming) = peer.add_connection(server_conn).await;
+ cx.background().spawn(io).detach();
+ let mut state = state.lock();
+ state.connection_id = Some(connection_id);
+ state.incoming = Some(incoming);
+ Ok(client_conn)
})
}
});
@@ -73,49 +94,25 @@ impl FakeServer {
pub fn disconnect(&self) {
self.peer.disconnect(self.connection_id());
- self.connection_id.lock().take();
- self.incoming.lock().take();
- }
-
- async fn establish_connection(
- &self,
- credentials: &Credentials,
- cx: &AsyncAppContext,
- ) -> Result<Connection, EstablishConnectionError> {
- assert_eq!(credentials.user_id, self.user_id);
-
- if self.forbid_connections.load(SeqCst) {
- Err(EstablishConnectionError::Other(anyhow!(
- "server is forbidding connections"
- )))?
- }
-
- if credentials.access_token != self.access_token.load(SeqCst).to_string() {
- Err(EstablishConnectionError::Unauthorized)?
- }
-
- let (client_conn, server_conn, _) = Connection::in_memory(cx.background());
- let (connection_id, io, incoming) = self.peer.add_connection(server_conn).await;
- cx.background().spawn(io).detach();
- *self.incoming.lock() = Some(incoming);
- *self.connection_id.lock() = Some(connection_id);
- Ok(client_conn)
+ let mut state = self.state.lock();
+ state.connection_id.take();
+ state.incoming.take();
}
pub fn auth_count(&self) -> usize {
- self.auth_count.load(SeqCst)
+ self.state.lock().auth_count
}
pub fn roll_access_token(&self) {
- self.access_token.fetch_add(1, SeqCst);
+ self.state.lock().access_token += 1;
}
pub fn forbid_connections(&self) {
- self.forbid_connections.store(true, SeqCst);
+ self.state.lock().forbid_connections = true;
}
pub fn allow_connections(&self) {
- self.forbid_connections.store(false, SeqCst);
+ self.state.lock().forbid_connections = false;
}
pub fn send<T: proto::EnvelopedMessage>(&self, message: T) {
@@ -123,14 +120,17 @@ impl FakeServer {
}
pub async fn receive<M: proto::EnvelopedMessage>(&self) -> Result<TypedEnvelope<M>> {
+ self.executor.start_waiting();
let message = self
- .incoming
+ .state
.lock()
+ .incoming
.as_mut()
.expect("not connected")
.next()
.await
.ok_or_else(|| anyhow!("other half hung up"))?;
+ self.executor.finish_waiting();
let type_name = message.payload_type_name();
Ok(*message
.into_any()
@@ -152,7 +152,7 @@ impl FakeServer {
}
fn connection_id(&self) -> ConnectionId {
- self.connection_id.lock().expect("not connected")
+ self.state.lock().connection_id.expect("not connected")
}
pub async fn build_user_store(
@@ -35,6 +35,7 @@ pub struct ProjectMetadata {
pub struct UserStore {
users: HashMap<u64, Arc<User>>,
+ update_contacts_tx: watch::Sender<Option<proto::UpdateContacts>>,
current_user: watch::Receiver<Option<Arc<User>>>,
contacts: Arc<[Contact]>,
client: Arc<Client>,
@@ -56,23 +57,19 @@ impl UserStore {
cx: &mut ModelContext<Self>,
) -> Self {
let (mut current_user_tx, current_user_rx) = watch::channel();
- let (mut update_contacts_tx, mut update_contacts_rx) =
+ let (update_contacts_tx, mut update_contacts_rx) =
watch::channel::<Option<proto::UpdateContacts>>();
- let update_contacts_subscription = client.add_message_handler(
- cx,
- move |_: ModelHandle<Self>, msg: TypedEnvelope<proto::UpdateContacts>, _, _| {
- *update_contacts_tx.borrow_mut() = Some(msg.payload);
- async move { Ok(()) }
- },
- );
+ let rpc_subscription =
+ client.add_message_handler(cx.handle(), Self::handle_update_contacts);
Self {
users: Default::default(),
current_user: current_user_rx,
contacts: Arc::from([]),
client: client.clone(),
+ update_contacts_tx,
http,
_maintain_contacts: cx.spawn_weak(|this, mut cx| async move {
- let _subscription = update_contacts_subscription;
+ let _subscription = rpc_subscription;
while let Some(message) = update_contacts_rx.recv().await {
if let Some((message, this)) = message.zip(this.upgrade(&cx)) {
this.update(&mut cx, |this, cx| this.update_contacts(message, cx))
@@ -104,6 +101,18 @@ impl UserStore {
}
}
+ async fn handle_update_contacts(
+ this: ModelHandle<Self>,
+ msg: TypedEnvelope<proto::UpdateContacts>,
+ _: Arc<Client>,
+ mut cx: AsyncAppContext,
+ ) -> Result<()> {
+ this.update(&mut cx, |this, _| {
+ *this.update_contacts_tx.borrow_mut() = Some(msg.payload);
+ });
+ Ok(())
+ }
+
fn update_contacts(
&mut self,
message: proto::UpdateContacts,
@@ -216,6 +216,16 @@ impl Global {
}
}
+impl FromIterator<Local> for Global {
+ fn from_iter<T: IntoIterator<Item = Local>>(locals: T) -> Self {
+ let mut result = Self::new();
+ for local in locals {
+ result.observe(local);
+ }
+ result
+ }
+}
+
impl Ord for Lamport {
fn cmp(&self, other: &Self) -> Ordering {
// Use the replica id to break ties between concurrent events.
@@ -57,7 +57,7 @@ impl View for DiagnosticSummary {
let theme = &self.settings.borrow().theme.project_diagnostics;
let in_progress = self.in_progress;
- MouseEventHandler::new::<Tag, _, _, _>(0, cx, |_, _| {
+ MouseEventHandler::new::<Tag, _, _, _>(cx.view_id(), cx, |_, _| {
if in_progress {
Label::new(
"Checking... ".to_string(),
@@ -2264,7 +2264,7 @@ impl Editor {
enum Tag {}
let style = (self.build_settings)(cx).style;
Some(
- MouseEventHandler::new::<Tag, _, _, _>(0, cx, |_, _| {
+ MouseEventHandler::new::<Tag, _, _, _>(cx.view_id(), cx, |_, _| {
Svg::new("icons/zap.svg")
.with_color(style.code_actions_indicator)
.boxed()
@@ -5447,8 +5447,8 @@ mod tests {
use super::*;
use language::LanguageConfig;
use lsp::FakeLanguageServer;
- use postage::prelude::Stream;
use project::{FakeFs, ProjectPath};
+ use smol::stream::StreamExt;
use std::{cell::RefCell, rc::Rc, time::Instant};
use text::Point;
use unindent::Unindent;
@@ -7737,9 +7737,8 @@ mod tests {
}),
..Default::default()
},
- &cx,
- )
- .await;
+ cx.background(),
+ );
let text = "
one
@@ -7792,7 +7791,9 @@ mod tests {
],
)
.await;
- editor.next_notification(&cx).await;
+ editor
+ .condition(&cx, |editor, _| editor.context_menu_visible())
+ .await;
let apply_additional_edits = editor.update(&mut cx, |editor, cx| {
editor.move_down(&MoveDown, cx);
@@ -7856,7 +7857,7 @@ mod tests {
)
.await;
editor
- .condition(&cx, |editor, _| editor.context_menu.is_some())
+ .condition(&cx, |editor, _| editor.context_menu_visible())
.await;
editor.update(&mut cx, |editor, cx| {
@@ -7874,7 +7875,9 @@ mod tests {
],
)
.await;
- editor.next_notification(&cx).await;
+ editor
+ .condition(&cx, |editor, _| editor.context_menu_visible())
+ .await;
let apply_additional_edits = editor.update(&mut cx, |editor, cx| {
let apply_additional_edits = editor
@@ -7912,7 +7915,7 @@ mod tests {
);
Some(lsp::CompletionResponse::Array(
completions
- .into_iter()
+ .iter()
.map(|(range, new_text)| lsp::CompletionItem {
label: new_text.to_string(),
text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
@@ -7927,7 +7930,7 @@ mod tests {
.collect(),
))
})
- .recv()
+ .next()
.await;
}
@@ -7937,7 +7940,7 @@ mod tests {
) {
fake.handle_request::<lsp::request::ResolveCompletionItem, _>(move |_| {
lsp::CompletionItem {
- additional_text_edits: edit.map(|(range, new_text)| {
+ additional_text_edits: edit.clone().map(|(range, new_text)| {
vec![lsp::TextEdit::new(
lsp::Range::new(
lsp::Position::new(range.start.row, range.start.column),
@@ -7949,7 +7952,7 @@ mod tests {
..Default::default()
}
})
- .recv()
+ .next()
.await;
}
}
@@ -1,6 +1,6 @@
use std::iter::FromIterator;
-#[derive(Copy, Clone, Debug, Default)]
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
pub struct CharBag(u64);
impl CharBag {
@@ -84,6 +84,8 @@ pub trait UpgradeModelHandle {
&self,
handle: &WeakModelHandle<T>,
) -> Option<ModelHandle<T>>;
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle>;
}
pub trait UpgradeViewHandle {
@@ -474,6 +476,10 @@ impl TestAppContext {
self.cx.borrow().cx.font_cache.clone()
}
+ pub fn foreground_platform(&self) -> Rc<platform::test::ForegroundPlatform> {
+ self.foreground_platform.clone()
+ }
+
pub fn platform(&self) -> Arc<dyn platform::Platform> {
self.cx.borrow().cx.platform.clone()
}
@@ -486,6 +492,15 @@ impl TestAppContext {
self.cx.borrow().background().clone()
}
+ pub fn spawn<F, Fut, T>(&self, f: F) -> Task<T>
+ where
+ F: FnOnce(AsyncAppContext) -> Fut,
+ Fut: 'static + Future<Output = T>,
+ T: 'static,
+ {
+ self.cx.borrow_mut().spawn(f)
+ }
+
pub fn simulate_new_path_selection(&self, result: impl FnOnce(PathBuf) -> Option<PathBuf>) {
self.foreground_platform.simulate_new_path_selection(result);
}
@@ -566,7 +581,11 @@ impl UpgradeModelHandle for AsyncAppContext {
&self,
handle: &WeakModelHandle<T>,
) -> Option<ModelHandle<T>> {
- self.0.borrow_mut().upgrade_model_handle(handle)
+ self.0.borrow().upgrade_model_handle(handle)
+ }
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
+ self.0.borrow().upgrade_any_model_handle(handle)
}
}
@@ -685,6 +704,7 @@ pub struct MutableAppContext {
next_entity_id: usize,
next_window_id: usize,
next_subscription_id: usize,
+ frame_count: usize,
subscriptions: Arc<Mutex<HashMap<usize, BTreeMap<usize, SubscriptionCallback>>>>,
observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ObservationCallback>>>>,
release_observations: Arc<Mutex<HashMap<usize, BTreeMap<usize, ReleaseObservationCallback>>>>,
@@ -729,6 +749,7 @@ impl MutableAppContext {
next_entity_id: 0,
next_window_id: 0,
next_subscription_id: 0,
+ frame_count: 0,
subscriptions: Default::default(),
observations: Default::default(),
release_observations: Default::default(),
@@ -920,6 +941,7 @@ impl MutableAppContext {
window_id: usize,
titlebar_height: f32,
) -> HashMap<usize, ElementBox> {
+ self.start_frame();
let view_ids = self
.views
.keys()
@@ -943,6 +965,10 @@ impl MutableAppContext {
.collect()
}
+ pub(crate) fn start_frame(&mut self) {
+ self.frame_count += 1;
+ }
+
pub fn update<T, F: FnOnce(&mut Self) -> T>(&mut self, callback: F) -> T {
self.pending_flushes += 1;
let result = callback(self);
@@ -1397,7 +1423,12 @@ impl MutableAppContext {
.element_states
.entry(key)
.or_insert_with(|| Box::new(T::default()));
- ElementStateHandle::new(TypeId::of::<Tag>(), id, &self.cx.ref_counts)
+ ElementStateHandle::new(
+ TypeId::of::<Tag>(),
+ id,
+ self.frame_count,
+ &self.cx.ref_counts,
+ )
}
fn remove_dropped_entities(&mut self) {
@@ -1748,6 +1779,10 @@ impl UpgradeModelHandle for MutableAppContext {
) -> Option<ModelHandle<T>> {
self.cx.upgrade_model_handle(handle)
}
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
+ self.cx.upgrade_any_model_handle(handle)
+ }
}
impl UpgradeViewHandle for MutableAppContext {
@@ -1872,6 +1907,19 @@ impl UpgradeModelHandle for AppContext {
None
}
}
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
+ if self.models.contains_key(&handle.model_id) {
+ self.ref_counts.lock().inc_model(handle.model_id);
+ Some(AnyModelHandle {
+ model_id: handle.model_id,
+ model_type: handle.model_type,
+ ref_counts: self.ref_counts.clone(),
+ })
+ } else {
+ None
+ }
+ }
}
impl UpgradeViewHandle for AppContext {
@@ -2264,6 +2312,10 @@ impl<M> UpgradeModelHandle for ModelContext<'_, M> {
) -> Option<ModelHandle<T>> {
self.cx.upgrade_model_handle(handle)
}
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
+ self.cx.upgrade_any_model_handle(handle)
+ }
}
impl<M> Deref for ModelContext<'_, M> {
@@ -2594,6 +2646,10 @@ impl<V> UpgradeModelHandle for ViewContext<'_, V> {
) -> Option<ModelHandle<T>> {
self.cx.upgrade_model_handle(handle)
}
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
+ self.cx.upgrade_any_model_handle(handle)
+ }
}
impl<V> UpgradeViewHandle for ViewContext<'_, V> {
@@ -3274,6 +3330,13 @@ impl AnyModelHandle {
}
}
+ pub fn downgrade(&self) -> AnyWeakModelHandle {
+ AnyWeakModelHandle {
+ model_id: self.model_id,
+ model_type: self.model_type,
+ }
+ }
+
pub fn is<T: Entity>(&self) -> bool {
self.model_type == TypeId::of::<T>()
}
@@ -3290,12 +3353,34 @@ impl<T: Entity> From<ModelHandle<T>> for AnyModelHandle {
}
}
+impl Clone for AnyModelHandle {
+ fn clone(&self) -> Self {
+ self.ref_counts.lock().inc_model(self.model_id);
+ Self {
+ model_id: self.model_id,
+ model_type: self.model_type,
+ ref_counts: self.ref_counts.clone(),
+ }
+ }
+}
+
impl Drop for AnyModelHandle {
fn drop(&mut self) {
self.ref_counts.lock().dec_model(self.model_id);
}
}
+pub struct AnyWeakModelHandle {
+ model_id: usize,
+ model_type: TypeId,
+}
+
+impl AnyWeakModelHandle {
+ pub fn upgrade(&self, cx: &impl UpgradeModelHandle) -> Option<AnyModelHandle> {
+ cx.upgrade_any_model_handle(self)
+ }
+}
+
pub struct WeakViewHandle<T> {
window_id: usize,
view_id: usize,
@@ -3368,8 +3453,15 @@ pub struct ElementStateHandle<T> {
}
impl<T: 'static> ElementStateHandle<T> {
- fn new(tag_type_id: TypeId, id: ElementStateId, ref_counts: &Arc<Mutex<RefCounts>>) -> Self {
- ref_counts.lock().inc_element_state(tag_type_id, id);
+ fn new(
+ tag_type_id: TypeId,
+ id: ElementStateId,
+ frame_id: usize,
+ ref_counts: &Arc<Mutex<RefCounts>>,
+ ) -> Self {
+ ref_counts
+ .lock()
+ .inc_element_state(tag_type_id, id, frame_id);
Self {
value_type: PhantomData,
tag_type_id,
@@ -3508,12 +3600,17 @@ impl Drop for Subscription {
#[derive(Default)]
struct RefCounts {
entity_counts: HashMap<usize, usize>,
- element_state_counts: HashMap<(TypeId, ElementStateId), usize>,
+ element_state_counts: HashMap<(TypeId, ElementStateId), ElementStateRefCount>,
dropped_models: HashSet<usize>,
dropped_views: HashSet<(usize, usize)>,
dropped_element_states: HashSet<(TypeId, ElementStateId)>,
}
+struct ElementStateRefCount {
+ ref_count: usize,
+ frame_id: usize,
+}
+
impl RefCounts {
fn inc_model(&mut self, model_id: usize) {
match self.entity_counts.entry(model_id) {
@@ -3537,11 +3634,21 @@ impl RefCounts {
}
}
- fn inc_element_state(&mut self, tag_type_id: TypeId, id: ElementStateId) {
+ fn inc_element_state(&mut self, tag_type_id: TypeId, id: ElementStateId, frame_id: usize) {
match self.element_state_counts.entry((tag_type_id, id)) {
- Entry::Occupied(mut entry) => *entry.get_mut() += 1,
+ Entry::Occupied(mut entry) => {
+ let entry = entry.get_mut();
+ if entry.frame_id == frame_id || entry.ref_count >= 2 {
+ panic!("used the same element state more than once in the same frame");
+ }
+ entry.ref_count += 1;
+ entry.frame_id = frame_id;
+ }
Entry::Vacant(entry) => {
- entry.insert(1);
+ entry.insert(ElementStateRefCount {
+ ref_count: 1,
+ frame_id,
+ });
self.dropped_element_states.remove(&(tag_type_id, id));
}
}
@@ -3567,9 +3674,9 @@ impl RefCounts {
fn dec_element_state(&mut self, tag_type_id: TypeId, id: ElementStateId) {
let key = (tag_type_id, id);
- let count = self.element_state_counts.get_mut(&key).unwrap();
- *count -= 1;
- if *count == 0 {
+ let entry = self.element_state_counts.get_mut(&key).unwrap();
+ entry.ref_count -= 1;
+ if entry.ref_count == 0 {
self.element_state_counts.remove(&key);
self.dropped_element_states.insert(key);
}
@@ -162,7 +162,6 @@ where
"UniformList does not support being rendered with an unconstrained height"
);
}
- let mut items = Vec::new();
if self.item_count == 0 {
return (
@@ -170,22 +169,27 @@ where
LayoutState {
item_height: 0.,
scroll_max: 0.,
- items,
+ items: Default::default(),
},
);
}
+ let mut items = Vec::new();
let mut size = constraint.max;
let mut item_size;
- if let Some(sample_item_ix) = self.get_width_from_item {
- (self.append_items)(sample_item_ix..sample_item_ix + 1, &mut items, cx);
- let sample_item = items.get_mut(0).unwrap();
+ let sample_item_ix;
+ let mut sample_item;
+ if let Some(sample_ix) = self.get_width_from_item {
+ (self.append_items)(sample_ix..sample_ix + 1, &mut items, cx);
+ sample_item_ix = sample_ix;
+ sample_item = items.pop().unwrap();
item_size = sample_item.layout(constraint, cx);
size.set_x(item_size.x());
} else {
(self.append_items)(0..1, &mut items, cx);
- let first_item = items.first_mut().unwrap();
- item_size = first_item.layout(
+ sample_item_ix = 0;
+ sample_item = items.pop().unwrap();
+ item_size = sample_item.layout(
SizeConstraint::new(
vec2f(constraint.max.x(), 0.0),
vec2f(constraint.max.x(), f32::INFINITY),
@@ -219,8 +223,21 @@ where
self.item_count,
start + (size.y() / item_height).ceil() as usize + 1,
);
- items.clear();
- (self.append_items)(start..end, &mut items, cx);
+
+ if (start..end).contains(&sample_item_ix) {
+ if sample_item_ix > start {
+ (self.append_items)(start..sample_item_ix, &mut items, cx);
+ }
+
+ items.push(sample_item);
+
+ if sample_item_ix < end {
+ (self.append_items)(sample_item_ix + 1..end, &mut items, cx);
+ }
+ } else {
+ (self.append_items)(start..end, &mut items, cx);
+ }
+
for item in &mut items {
let item_size = item.layout(item_constraint, cx);
if item_size.x() > size.x() {
@@ -370,6 +370,13 @@ impl Foreground {
*any_value.downcast().unwrap()
}
+ pub fn run_until_parked(&self) {
+ match self {
+ Self::Deterministic { executor, .. } => executor.run_until_parked(),
+ _ => panic!("this method can only be called on a deterministic executor"),
+ }
+ }
+
pub fn parking_forbidden(&self) -> bool {
match self {
Self::Deterministic { executor, .. } => executor.state.lock().forbid_parking,
@@ -6,9 +6,9 @@ use crate::{
json::{self, ToJson},
platform::Event,
text_layout::TextLayoutCache,
- Action, AnyAction, AnyViewHandle, AssetCache, ElementBox, Entity, FontSystem, ModelHandle,
- ReadModel, ReadView, Scene, UpgradeModelHandle, UpgradeViewHandle, View, ViewHandle,
- WeakModelHandle, WeakViewHandle,
+ Action, AnyAction, AnyModelHandle, AnyViewHandle, AnyWeakModelHandle, AssetCache, ElementBox,
+ Entity, FontSystem, ModelHandle, ReadModel, ReadView, Scene, UpgradeModelHandle,
+ UpgradeViewHandle, View, ViewHandle, WeakModelHandle, WeakViewHandle,
};
use pathfinder_geometry::vector::{vec2f, Vector2F};
use serde_json::json;
@@ -62,6 +62,7 @@ impl Presenter {
}
pub fn invalidate(&mut self, mut invalidation: WindowInvalidation, cx: &mut MutableAppContext) {
+ cx.start_frame();
for view_id in invalidation.removed {
invalidation.updated.remove(&view_id);
self.rendered_views.remove(&view_id);
@@ -81,6 +82,7 @@ impl Presenter {
invalidation: Option<WindowInvalidation>,
cx: &mut MutableAppContext,
) {
+ cx.start_frame();
if let Some(invalidation) = invalidation {
for view_id in invalidation.removed {
self.rendered_views.remove(&view_id);
@@ -278,6 +280,10 @@ impl<'a> UpgradeModelHandle for LayoutContext<'a> {
) -> Option<ModelHandle<T>> {
self.app.upgrade_model_handle(handle)
}
+
+ fn upgrade_any_model_handle(&self, handle: &AnyWeakModelHandle) -> Option<AnyModelHandle> {
+ self.app.upgrade_any_model_handle(handle)
+ }
}
impl<'a> UpgradeViewHandle for LayoutContext<'a> {
@@ -33,6 +33,7 @@ pub fn run_test(
Rc<platform::test::ForegroundPlatform>,
Arc<executor::Deterministic>,
u64,
+ bool,
)),
) {
let is_randomized = num_iterations > 1;
@@ -56,10 +57,8 @@ pub fn run_test(
let font_cache = Arc::new(FontCache::new(font_system));
loop {
- let seed = atomic_seed.load(SeqCst);
- if seed >= starting_seed + num_iterations {
- break;
- }
+ let seed = atomic_seed.fetch_add(1, SeqCst);
+ let is_last_iteration = seed + 1 >= starting_seed + num_iterations;
if is_randomized {
dbg!(seed);
@@ -74,9 +73,19 @@ pub fn run_test(
font_cache.clone(),
0,
);
- cx.update(|cx| test_fn(cx, foreground_platform.clone(), deterministic, seed));
-
- atomic_seed.fetch_add(1, SeqCst);
+ cx.update(|cx| {
+ test_fn(
+ cx,
+ foreground_platform.clone(),
+ deterministic,
+ seed,
+ is_last_iteration,
+ )
+ });
+
+ if is_last_iteration {
+ break;
+ }
}
});
@@ -90,7 +99,7 @@ pub fn run_test(
println!("retrying: attempt {}", retries);
} else {
if is_randomized {
- eprintln!("failing seed: {}", atomic_seed.load(SeqCst));
+ eprintln!("failing seed: {}", atomic_seed.load(SeqCst) - 1);
}
panic::resume_unwind(error);
}
@@ -85,7 +85,10 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
));
}
Some("StdRng") => {
- inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed)));
+ inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
+ }
+ Some("bool") => {
+ inner_fn_args.extend(quote!(is_last_iteration,));
}
_ => {
return TokenStream::from(
@@ -115,7 +118,9 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
#num_iterations as u64,
#starting_seed as u64,
#max_retries,
- &mut |cx, foreground_platform, deterministic, seed| cx.foreground().run(#inner_fn_name(#inner_fn_args))
+ &mut |cx, foreground_platform, deterministic, seed, is_last_iteration| {
+ cx.foreground().run(#inner_fn_name(#inner_fn_args))
+ }
);
}
}
@@ -125,8 +130,14 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
if let FnArg::Typed(arg) = arg {
if let Type::Path(ty) = &*arg.ty {
let last_segment = ty.path.segments.last();
- if let Some("StdRng") = last_segment.map(|s| s.ident.to_string()).as_deref() {
- inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
+ match last_segment.map(|s| s.ident.to_string()).as_deref() {
+ Some("StdRng") => {
+ inner_fn_args.extend(quote!(rand::SeedableRng::seed_from_u64(seed),));
+ }
+ Some("bool") => {
+ inner_fn_args.extend(quote!(is_last_iteration,));
+ }
+ _ => {}
}
} else {
inner_fn_args.extend(quote!(cx,));
@@ -147,7 +158,7 @@ pub fn test(args: TokenStream, function: TokenStream) -> TokenStream {
#num_iterations as u64,
#starting_seed as u64,
#max_retries,
- &mut |cx, _, _, seed| #inner_fn_name(#inner_fn_args)
+ &mut |cx, _, _, seed, is_last_iteration| #inner_fn_name(#inner_fn_args)
);
}
}
@@ -1283,6 +1283,10 @@ impl Buffer {
self.text.wait_for_edits(edit_ids)
}
+ pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future<Output = ()> {
+ self.text.wait_for_version(version)
+ }
+
pub fn set_active_selections(
&mut self,
selections: Arc<[Selection<Anchor>]>,
@@ -17,6 +17,9 @@ use std::{cell::RefCell, ops::Range, path::Path, str, sync::Arc};
use theme::SyntaxTheme;
use tree_sitter::{self, Query};
+#[cfg(any(test, feature = "test-support"))]
+use futures::channel::mpsc;
+
pub use buffer::Operation;
pub use buffer::*;
pub use diagnostic_set::DiagnosticEntry;
@@ -79,7 +82,14 @@ pub struct LanguageServerConfig {
pub disk_based_diagnostics_progress_token: Option<String>,
#[cfg(any(test, feature = "test-support"))]
#[serde(skip)]
- pub fake_server: Option<(Arc<lsp::LanguageServer>, Arc<std::sync::atomic::AtomicBool>)>,
+ fake_config: Option<FakeLanguageServerConfig>,
+}
+
+#[cfg(any(test, feature = "test-support"))]
+struct FakeLanguageServerConfig {
+ servers_tx: mpsc::UnboundedSender<lsp::FakeLanguageServer>,
+ capabilities: lsp::ServerCapabilities,
+ initializer: Option<Box<dyn 'static + Send + Sync + Fn(&mut lsp::FakeLanguageServer)>>,
}
#[derive(Clone, Debug, Deserialize)]
@@ -224,8 +234,27 @@ impl Language {
) -> Result<Option<Arc<lsp::LanguageServer>>> {
if let Some(config) = &self.config.language_server {
#[cfg(any(test, feature = "test-support"))]
- if let Some((server, started)) = &config.fake_server {
- started.store(true, std::sync::atomic::Ordering::SeqCst);
+ if let Some(fake_config) = &config.fake_config {
+ use postage::prelude::Stream;
+
+ let (server, mut fake_server) = lsp::LanguageServer::fake_with_capabilities(
+ fake_config.capabilities.clone(),
+ cx.background().clone(),
+ );
+
+ if let Some(initalizer) = &fake_config.initializer {
+ initalizer(&mut fake_server);
+ }
+
+ let servers_tx = fake_config.servers_tx.clone();
+ let mut initialized = server.capabilities();
+ cx.background()
+ .spawn(async move {
+ while initialized.recv().await.is_none() {}
+ servers_tx.unbounded_send(fake_server).ok();
+ })
+ .detach();
+
return Ok(Some(server.clone()));
}
@@ -357,27 +386,32 @@ impl CompletionLabel {
#[cfg(any(test, feature = "test-support"))]
impl LanguageServerConfig {
- pub async fn fake(cx: &gpui::TestAppContext) -> (Self, lsp::FakeLanguageServer) {
- Self::fake_with_capabilities(Default::default(), cx).await
- }
-
- pub async fn fake_with_capabilities(
- capabilites: lsp::ServerCapabilities,
- cx: &gpui::TestAppContext,
- ) -> (Self, lsp::FakeLanguageServer) {
- let (server, fake) = lsp::LanguageServer::fake_with_capabilities(capabilites, cx).await;
- fake.started
- .store(false, std::sync::atomic::Ordering::SeqCst);
- let started = fake.started.clone();
+ pub fn fake() -> (Self, mpsc::UnboundedReceiver<lsp::FakeLanguageServer>) {
+ let (servers_tx, servers_rx) = mpsc::unbounded();
(
Self {
- fake_server: Some((server, started)),
+ fake_config: Some(FakeLanguageServerConfig {
+ servers_tx,
+ capabilities: Default::default(),
+ initializer: None,
+ }),
disk_based_diagnostics_progress_token: Some("fakeServer/check".to_string()),
..Default::default()
},
- fake,
+ servers_rx,
)
}
+
+ pub fn set_fake_capabilities(&mut self, capabilities: lsp::ServerCapabilities) {
+ self.fake_config.as_mut().unwrap().capabilities = capabilities;
+ }
+
+ pub fn set_fake_initializer(
+ &mut self,
+ initializer: impl 'static + Send + Sync + Fn(&mut lsp::FakeLanguageServer),
+ ) {
+ self.fake_config.as_mut().unwrap().initializer = Some(Box::new(initializer));
+ }
}
impl ToLspPosition for PointUtf16 {
@@ -557,7 +557,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut MutableAppConte
#[gpui::test]
async fn test_diagnostics(mut cx: gpui::TestAppContext) {
- let (language_server, mut fake) = lsp::LanguageServer::fake(&cx).await;
+ let (language_server, mut fake) = lsp::LanguageServer::fake(cx.background());
let mut rust_lang = rust_lang();
rust_lang.config.language_server = Some(LanguageServerConfig {
disk_based_diagnostic_sources: HashSet::from_iter(["disk".to_string()]),
@@ -840,7 +840,7 @@ async fn test_diagnostics(mut cx: gpui::TestAppContext) {
#[gpui::test]
async fn test_edits_from_lsp_with_past_version(mut cx: gpui::TestAppContext) {
- let (language_server, mut fake) = lsp::LanguageServer::fake(&cx).await;
+ let (language_server, mut fake) = lsp::LanguageServer::fake(cx.background());
let text = "
fn a() {
@@ -420,7 +420,9 @@ impl LanguageServer {
anyhow!("tried to send a request to a language server that has been shut down")
})
.and_then(|outbound_tx| {
- outbound_tx.try_send(message)?;
+ outbound_tx
+ .try_send(message)
+ .context("failed to write to language server's stdin")?;
Ok(())
});
async move {
@@ -481,43 +483,36 @@ impl Drop for Subscription {
#[cfg(any(test, feature = "test-support"))]
pub struct FakeLanguageServer {
- handlers: Arc<
- Mutex<
- HashMap<
- &'static str,
- Box<dyn Send + FnOnce(usize, &[u8]) -> (Vec<u8>, barrier::Sender)>,
- >,
- >,
- >,
- outgoing_tx: channel::Sender<Vec<u8>>,
- incoming_rx: channel::Receiver<Vec<u8>>,
- pub started: Arc<std::sync::atomic::AtomicBool>,
+ handlers:
+ Arc<Mutex<HashMap<&'static str, Box<dyn Send + Sync + FnMut(usize, &[u8]) -> Vec<u8>>>>>,
+ outgoing_tx: futures::channel::mpsc::UnboundedSender<Vec<u8>>,
+ incoming_rx: futures::channel::mpsc::UnboundedReceiver<Vec<u8>>,
}
#[cfg(any(test, feature = "test-support"))]
impl LanguageServer {
- pub async fn fake(cx: &gpui::TestAppContext) -> (Arc<Self>, FakeLanguageServer) {
- Self::fake_with_capabilities(Default::default(), cx).await
+ pub fn fake(executor: Arc<gpui::executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
+ Self::fake_with_capabilities(Default::default(), executor)
}
- pub async fn fake_with_capabilities(
+ pub fn fake_with_capabilities(
capabilities: ServerCapabilities,
- cx: &gpui::TestAppContext,
+ executor: Arc<gpui::executor::Background>,
) -> (Arc<Self>, FakeLanguageServer) {
let (stdin_writer, stdin_reader) = async_pipe::pipe();
let (stdout_writer, stdout_reader) = async_pipe::pipe();
- let mut fake = FakeLanguageServer::new(cx, stdin_reader, stdout_writer);
- fake.handle_request::<request::Initialize, _>(move |_| InitializeResult {
- capabilities,
- ..Default::default()
+ let mut fake = FakeLanguageServer::new(executor.clone(), stdin_reader, stdout_writer);
+ fake.handle_request::<request::Initialize, _>({
+ let capabilities = capabilities.clone();
+ move |_| InitializeResult {
+ capabilities: capabilities.clone(),
+ ..Default::default()
+ }
});
let server =
- Self::new_internal(stdin_writer, stdout_reader, Path::new("/"), cx.background())
- .unwrap();
- fake.receive_notification::<notification::Initialized>()
- .await;
+ Self::new_internal(stdin_writer, stdout_reader, Path::new("/"), executor).unwrap();
(server, fake)
}
@@ -526,63 +521,59 @@ impl LanguageServer {
#[cfg(any(test, feature = "test-support"))]
impl FakeLanguageServer {
fn new(
- cx: &gpui::TestAppContext,
+ background: Arc<gpui::executor::Background>,
stdin: async_pipe::PipeReader,
stdout: async_pipe::PipeWriter,
) -> Self {
use futures::StreamExt as _;
- let (incoming_tx, incoming_rx) = channel::unbounded();
- let (outgoing_tx, mut outgoing_rx) = channel::unbounded();
+ let (incoming_tx, incoming_rx) = futures::channel::mpsc::unbounded();
+ let (outgoing_tx, mut outgoing_rx) = futures::channel::mpsc::unbounded();
let this = Self {
outgoing_tx: outgoing_tx.clone(),
incoming_rx,
handlers: Default::default(),
- started: Arc::new(std::sync::atomic::AtomicBool::new(true)),
};
// Receive incoming messages
let handlers = this.handlers.clone();
- cx.background()
+ let executor = background.clone();
+ background
.spawn(async move {
let mut buffer = Vec::new();
let mut stdin = smol::io::BufReader::new(stdin);
while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
- if let Ok(request) = serde_json::from_slice::<AnyRequest>(&mut buffer) {
+ executor.simulate_random_delay().await;
+ if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
- let handler = handlers.lock().remove(request.method);
- if let Some(handler) = handler {
- let (response, sent) =
- handler(request.id, request.params.get().as_bytes());
+ if let Some(handler) = handlers.lock().get_mut(request.method) {
+ let response = handler(request.id, request.params.get().as_bytes());
log::debug!("handled lsp request. method:{}", request.method);
- outgoing_tx.send(response).await.unwrap();
- drop(sent);
+ outgoing_tx.unbounded_send(response)?;
} else {
log::debug!("unhandled lsp request. method:{}", request.method);
- outgoing_tx
- .send(
- serde_json::to_vec(&AnyResponse {
- id: request.id,
- error: Some(Error {
- message: "no handler".to_string(),
- }),
- result: None,
- })
- .unwrap(),
- )
- .await
- .unwrap();
+ outgoing_tx.unbounded_send(
+ serde_json::to_vec(&AnyResponse {
+ id: request.id,
+ error: Some(Error {
+ message: "no handler".to_string(),
+ }),
+ result: None,
+ })
+ .unwrap(),
+ )?;
}
} else {
- incoming_tx.send(buffer.clone()).await.unwrap();
+ incoming_tx.unbounded_send(buffer.clone())?;
}
}
+ Ok::<_, anyhow::Error>(())
})
.detach();
// Send outgoing messages
- cx.background()
+ background
.spawn(async move {
let mut stdout = smol::io::BufWriter::new(stdout);
while let Some(notification) = outgoing_rx.next().await {
@@ -595,16 +586,13 @@ impl FakeLanguageServer {
}
pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
- if !self.started.load(std::sync::atomic::Ordering::SeqCst) {
- panic!("can't simulate an LSP notification before the server has been started");
- }
let message = serde_json::to_vec(&Notification {
jsonrpc: JSON_RPC_VERSION,
method: T::METHOD,
params,
})
.unwrap();
- self.outgoing_tx.send(message).await.unwrap();
+ self.outgoing_tx.unbounded_send(message).unwrap();
}
pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
@@ -624,15 +612,18 @@ impl FakeLanguageServer {
}
}
- pub fn handle_request<T, F>(&mut self, handler: F) -> barrier::Receiver
+ pub fn handle_request<T, F>(
+ &mut self,
+ mut handler: F,
+ ) -> futures::channel::mpsc::UnboundedReceiver<()>
where
T: 'static + request::Request,
- F: 'static + Send + FnOnce(T::Params) -> T::Result,
+ F: 'static + Send + Sync + FnMut(T::Params) -> T::Result,
{
- let (responded_tx, responded_rx) = barrier::channel();
- let prev_handler = self.handlers.lock().insert(
+ let (responded_tx, responded_rx) = futures::channel::mpsc::unbounded();
+ self.handlers.lock().insert(
T::METHOD,
- Box::new(|id, params| {
+ Box::new(move |id, params| {
let result = handler(serde_json::from_slice::<T::Params>(params).unwrap());
let result = serde_json::to_string(&result).unwrap();
let result = serde_json::from_str::<&RawValue>(&result).unwrap();
@@ -641,18 +632,20 @@ impl FakeLanguageServer {
error: None,
result: Some(result),
};
- (serde_json::to_vec(&response).unwrap(), responded_tx)
+ responded_tx.unbounded_send(()).ok();
+ serde_json::to_vec(&response).unwrap()
}),
);
- if prev_handler.is_some() {
- panic!(
- "registered a new handler for LSP method '{}' before the previous handler was called",
- T::METHOD
- );
- }
responded_rx
}
+ pub fn remove_request_handler<T>(&mut self)
+ where
+ T: 'static + request::Request,
+ {
+ self.handlers.lock().remove(T::METHOD);
+ }
+
pub async fn start_progress(&mut self, token: impl Into<String>) {
self.notify::<notification::Progress>(ProgressParams {
token: NumberOrString::String(token.into()),
@@ -777,7 +770,7 @@ mod tests {
#[gpui::test]
async fn test_fake(cx: TestAppContext) {
- let (server, mut fake) = LanguageServer::fake(&cx).await;
+ let (server, mut fake) = LanguageServer::fake(cx.background());
let (message_tx, message_rx) = channel::unbounded();
let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
@@ -2,7 +2,7 @@ pub mod fs;
mod ignore;
pub mod worktree;
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, Context, Result};
use client::{proto, Client, PeerId, TypedEnvelope, User, UserStore};
use clock::ReplicaId;
use collections::{hash_map, HashMap, HashSet};
@@ -10,17 +10,17 @@ use futures::Future;
use fuzzy::{PathMatch, PathMatchCandidate, PathMatchCandidateSet};
use gpui::{
AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext, Task,
- WeakModelHandle,
+ UpgradeModelHandle, WeakModelHandle,
};
use language::{
point_from_lsp,
proto::{deserialize_anchor, serialize_anchor},
range_from_lsp, AnchorRangeExt, Bias, Buffer, CodeAction, Completion, CompletionLabel,
- Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, PointUtf16, ToLspPosition,
- ToOffset, ToPointUtf16, Transaction,
+ Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, Operation, PointUtf16,
+ ToLspPosition, ToOffset, ToPointUtf16, Transaction,
};
use lsp::{DiagnosticSeverity, LanguageServer};
-use postage::{prelude::Stream, watch};
+use postage::{broadcast, prelude::Stream, sink::Sink, watch};
use smol::block_on;
use std::{
convert::TryInto,
@@ -46,7 +46,8 @@ pub struct Project {
collaborators: HashMap<PeerId, Collaborator>,
subscriptions: Vec<client::Subscription>,
language_servers_with_diagnostics_running: isize,
- open_buffers: HashMap<usize, WeakModelHandle<Buffer>>,
+ open_buffers: HashMap<u64, OpenBuffer>,
+ opened_buffer: broadcast::Sender<()>,
loading_buffers: HashMap<
ProjectPath,
postage::watch::Receiver<Option<Result<ModelHandle<Buffer>, Arc<anyhow::Error>>>>,
@@ -54,6 +55,11 @@ pub struct Project {
shared_buffers: HashMap<PeerId, HashMap<u64, ModelHandle<Buffer>>>,
}
+enum OpenBuffer {
+ Loaded(WeakModelHandle<Buffer>),
+ Loading(Vec<Operation>),
+}
+
enum WorktreeHandle {
Strong(ModelHandle<Worktree>),
Weak(WeakModelHandle<Worktree>),
@@ -155,6 +161,31 @@ pub struct ProjectEntry {
}
impl Project {
+ pub fn init(client: &Arc<Client>) {
+ client.add_entity_message_handler(Self::handle_add_collaborator);
+ client.add_entity_message_handler(Self::handle_buffer_reloaded);
+ client.add_entity_message_handler(Self::handle_buffer_saved);
+ client.add_entity_message_handler(Self::handle_close_buffer);
+ client.add_entity_message_handler(Self::handle_disk_based_diagnostics_updated);
+ client.add_entity_message_handler(Self::handle_disk_based_diagnostics_updating);
+ client.add_entity_message_handler(Self::handle_remove_collaborator);
+ client.add_entity_message_handler(Self::handle_share_worktree);
+ client.add_entity_message_handler(Self::handle_unregister_worktree);
+ client.add_entity_message_handler(Self::handle_unshare_project);
+ client.add_entity_message_handler(Self::handle_update_buffer_file);
+ client.add_entity_message_handler(Self::handle_update_buffer);
+ client.add_entity_message_handler(Self::handle_update_diagnostic_summary);
+ client.add_entity_message_handler(Self::handle_update_worktree);
+ client.add_entity_request_handler(Self::handle_apply_additional_edits_for_completion);
+ client.add_entity_request_handler(Self::handle_apply_code_action);
+ client.add_entity_request_handler(Self::handle_format_buffers);
+ client.add_entity_request_handler(Self::handle_get_code_actions);
+ client.add_entity_request_handler(Self::handle_get_completions);
+ client.add_entity_request_handler(Self::handle_get_definition);
+ client.add_entity_request_handler(Self::handle_open_buffer);
+ client.add_entity_request_handler(Self::handle_save_buffer);
+ }
+
pub fn local(
client: Arc<Client>,
user_store: ModelHandle<UserStore>,
@@ -216,6 +247,7 @@ impl Project {
remote_id_rx,
_maintain_remote_id_task,
},
+ opened_buffer: broadcast::channel(1).0,
subscriptions: Vec::new(),
active_entry: None,
languages,
@@ -254,70 +286,19 @@ impl Project {
load_task.detach();
}
- let user_ids = response
- .collaborators
- .iter()
- .map(|peer| peer.user_id)
- .collect();
- user_store
- .update(cx, |user_store, cx| user_store.load_users(user_ids, cx))
- .await?;
- let mut collaborators = HashMap::default();
- for message in response.collaborators {
- let collaborator = Collaborator::from_proto(message, &user_store, cx).await?;
- collaborators.insert(collaborator.peer_id, collaborator);
- }
-
- Ok(cx.add_model(|cx| {
+ let this = cx.add_model(|cx| {
let mut this = Self {
worktrees: Vec::new(),
open_buffers: Default::default(),
loading_buffers: Default::default(),
+ opened_buffer: broadcast::channel(1).0,
shared_buffers: Default::default(),
active_entry: None,
- collaborators,
+ collaborators: Default::default(),
languages,
- user_store,
+ user_store: user_store.clone(),
fs,
- subscriptions: vec![
- client.add_entity_message_handler(remote_id, cx, Self::handle_unshare_project),
- client.add_entity_message_handler(remote_id, cx, Self::handle_add_collaborator),
- client.add_entity_message_handler(
- remote_id,
- cx,
- Self::handle_remove_collaborator,
- ),
- client.add_entity_message_handler(remote_id, cx, Self::handle_share_worktree),
- client.add_entity_message_handler(
- remote_id,
- cx,
- Self::handle_unregister_worktree,
- ),
- client.add_entity_message_handler(remote_id, cx, Self::handle_update_worktree),
- client.add_entity_message_handler(
- remote_id,
- cx,
- Self::handle_update_diagnostic_summary,
- ),
- client.add_entity_message_handler(
- remote_id,
- cx,
- Self::handle_disk_based_diagnostics_updating,
- ),
- client.add_entity_message_handler(
- remote_id,
- cx,
- Self::handle_disk_based_diagnostics_updated,
- ),
- client.add_entity_message_handler(remote_id, cx, Self::handle_update_buffer),
- client.add_entity_message_handler(
- remote_id,
- cx,
- Self::handle_update_buffer_file,
- ),
- client.add_entity_message_handler(remote_id, cx, Self::handle_buffer_reloaded),
- client.add_entity_message_handler(remote_id, cx, Self::handle_buffer_saved),
- ],
+ subscriptions: vec![client.add_model_for_remote_entity(remote_id, cx)],
client,
client_state: ProjectClientState::Remote {
sharing_has_stopped: false,
@@ -331,7 +312,27 @@ impl Project {
this.add_worktree(&worktree, cx);
}
this
- }))
+ });
+
+ let user_ids = response
+ .collaborators
+ .iter()
+ .map(|peer| peer.user_id)
+ .collect();
+ user_store
+ .update(cx, |user_store, cx| user_store.load_users(user_ids, cx))
+ .await?;
+ let mut collaborators = HashMap::default();
+ for message in response.collaborators {
+ let collaborator = Collaborator::from_proto(message, &user_store, cx).await?;
+ collaborators.insert(collaborator.peer_id, collaborator);
+ }
+
+ this.update(cx, |this, _| {
+ this.collaborators = collaborators;
+ });
+
+ Ok(this)
}
#[cfg(any(test, feature = "test-support"))]
@@ -343,6 +344,25 @@ impl Project {
cx.update(|cx| Project::local(client, user_store, languages, fs, cx))
}
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn shared_buffer(&self, peer_id: PeerId, remote_id: u64) -> Option<ModelHandle<Buffer>> {
+ self.shared_buffers
+ .get(&peer_id)
+ .and_then(|buffers| buffers.get(&remote_id))
+ .cloned()
+ }
+
+ #[cfg(any(test, feature = "test-support"))]
+ pub fn has_buffered_operations(&self) -> bool {
+ self.open_buffers
+ .values()
+ .any(|buffer| matches!(buffer, OpenBuffer::Loading(_)))
+ }
+
+ pub fn fs(&self) -> &Arc<dyn Fs> {
+ &self.fs
+ }
+
fn set_remote_id(&mut self, remote_id: Option<u64>, cx: &mut ModelContext<Self>) {
if let ProjectClientState::Local { remote_id_tx, .. } = &mut self.client_state {
*remote_id_tx.borrow_mut() = remote_id;
@@ -350,27 +370,8 @@ impl Project {
self.subscriptions.clear();
if let Some(remote_id) = remote_id {
- let client = &self.client;
- self.subscriptions.extend([
- client.add_entity_request_handler(remote_id, cx, Self::handle_open_buffer),
- client.add_entity_message_handler(remote_id, cx, Self::handle_close_buffer),
- client.add_entity_message_handler(remote_id, cx, Self::handle_add_collaborator),
- client.add_entity_message_handler(remote_id, cx, Self::handle_remove_collaborator),
- client.add_entity_message_handler(remote_id, cx, Self::handle_update_worktree),
- client.add_entity_message_handler(remote_id, cx, Self::handle_update_buffer),
- client.add_entity_request_handler(remote_id, cx, Self::handle_save_buffer),
- client.add_entity_message_handler(remote_id, cx, Self::handle_buffer_saved),
- client.add_entity_request_handler(remote_id, cx, Self::handle_format_buffers),
- client.add_entity_request_handler(remote_id, cx, Self::handle_get_completions),
- client.add_entity_request_handler(
- remote_id,
- cx,
- Self::handle_apply_additional_edits_for_completion,
- ),
- client.add_entity_request_handler(remote_id, cx, Self::handle_get_code_actions),
- client.add_entity_request_handler(remote_id, cx, Self::handle_apply_code_action),
- client.add_entity_request_handler(remote_id, cx, Self::handle_get_definition),
- ]);
+ self.subscriptions
+ .push(self.client.add_model_for_remote_entity(remote_id, cx));
}
}
@@ -521,6 +522,10 @@ impl Project {
}
}
+ pub fn is_remote(&self) -> bool {
+ !self.is_local()
+ }
+
pub fn open_buffer(
&mut self,
path: impl Into<ProjectPath>,
@@ -560,6 +565,11 @@ impl Project {
*tx.borrow_mut() = Some(this.update(&mut cx, |this, _| {
// Record the fact that the buffer is no longer loading.
this.loading_buffers.remove(&project_path);
+ if this.loading_buffers.is_empty() {
+ this.open_buffers
+ .retain(|_, buffer| matches!(buffer, OpenBuffer::Loaded(_)))
+ }
+
let buffer = load_result.map_err(Arc::new)?;
Ok(buffer)
}));
@@ -626,6 +636,7 @@ impl Project {
.await?;
let buffer = response.buffer.ok_or_else(|| anyhow!("missing buffer"))?;
this.update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx))
+ .await
})
}
@@ -737,12 +748,15 @@ impl Project {
worktree: Option<&ModelHandle<Worktree>>,
cx: &mut ModelContext<Self>,
) -> Result<()> {
- if self
- .open_buffers
- .insert(buffer.read(cx).remote_id() as usize, buffer.downgrade())
- .is_some()
- {
- return Err(anyhow!("registered the same buffer twice"));
+ match self.open_buffers.insert(
+ buffer.read(cx).remote_id(),
+ OpenBuffer::Loaded(buffer.downgrade()),
+ ) {
+ None => {}
+ Some(OpenBuffer::Loading(operations)) => {
+ buffer.update(cx, |buffer, cx| buffer.apply_ops(operations, cx))?
+ }
+ Some(OpenBuffer::Loaded(_)) => Err(anyhow!("registered the same buffer twice"))?,
}
self.assign_language_to_buffer(&buffer, worktree, cx);
Ok(())
@@ -1263,29 +1277,27 @@ impl Project {
};
cx.spawn(|this, mut cx| async move {
let response = client.request(request).await?;
- this.update(&mut cx, |this, cx| {
- let mut definitions = Vec::new();
- for definition in response.definitions {
- let target_buffer = this.deserialize_buffer(
- definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?,
- cx,
- )?;
- let target_start = definition
- .target_start
- .and_then(deserialize_anchor)
- .ok_or_else(|| anyhow!("missing target start"))?;
- let target_end = definition
- .target_end
- .and_then(deserialize_anchor)
- .ok_or_else(|| anyhow!("missing target end"))?;
- definitions.push(Definition {
- target_buffer,
- target_range: target_start..target_end,
- })
- }
+ let mut definitions = Vec::new();
+ for definition in response.definitions {
+ let buffer = definition.buffer.ok_or_else(|| anyhow!("missing buffer"))?;
+ let target_buffer = this
+ .update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx))
+ .await?;
+ let target_start = definition
+ .target_start
+ .and_then(deserialize_anchor)
+ .ok_or_else(|| anyhow!("missing target start"))?;
+ let target_end = definition
+ .target_end
+ .and_then(deserialize_anchor)
+ .ok_or_else(|| anyhow!("missing target end"))?;
+ definitions.push(Definition {
+ target_buffer,
+ target_range: target_start..target_end,
+ })
+ }
- Ok(definitions)
- })
+ Ok(definitions)
})
} else {
Task::ready(Ok(Default::default()))
@@ -1324,18 +1336,19 @@ impl Project {
cx.spawn(|_, cx| async move {
let completions = lang_server
- .request::<lsp::request::Completion>(lsp::CompletionParams {
- text_document_position: lsp::TextDocumentPositionParams::new(
- lsp::TextDocumentIdentifier::new(
- lsp::Url::from_file_path(buffer_abs_path).unwrap(),
+ .request::<lsp::request::Completion>(lsp::CompletionParams {
+ text_document_position: lsp::TextDocumentPositionParams::new(
+ lsp::TextDocumentIdentifier::new(
+ lsp::Url::from_file_path(buffer_abs_path).unwrap(),
+ ),
+ position.to_lsp_position(),
),
- position.to_lsp_position(),
- ),
- context: Default::default(),
- work_done_progress_params: Default::default(),
- partial_result_params: Default::default(),
- })
- .await?;
+ context: Default::default(),
+ work_done_progress_params: Default::default(),
+ partial_result_params: Default::default(),
+ })
+ .await
+ .context("lsp completion request failed")?;
let completions = if let Some(completions) = completions {
match completions {
@@ -1347,41 +1360,56 @@ impl Project {
};
source_buffer_handle.read_with(&cx, |this, _| {
- Ok(completions.into_iter().filter_map(|lsp_completion| {
- let (old_range, new_text) = match lsp_completion.text_edit.as_ref()? {
- lsp::CompletionTextEdit::Edit(edit) => (range_from_lsp(edit.range), edit.new_text.clone()),
- lsp::CompletionTextEdit::InsertAndReplace(_) => {
- log::info!("received an insert and replace completion but we don't yet support that");
- return None
- },
- };
-
- let clipped_start = this.clip_point_utf16(old_range.start, Bias::Left);
- let clipped_end = this.clip_point_utf16(old_range.end, Bias::Left) ;
- if clipped_start == old_range.start && clipped_end == old_range.end {
- Some(Completion {
- old_range: this.anchor_before(old_range.start)..this.anchor_after(old_range.end),
- new_text,
- label: language.as_ref().and_then(|l| l.label_for_completion(&lsp_completion)).unwrap_or_else(|| CompletionLabel::plain(&lsp_completion)),
- lsp_completion,
- })
- } else {
- None
- }
- }).collect())
+ Ok(completions
+ .into_iter()
+ .filter_map(|lsp_completion| {
+ let (old_range, new_text) = match lsp_completion.text_edit.as_ref()? {
+ lsp::CompletionTextEdit::Edit(edit) => {
+ (range_from_lsp(edit.range), edit.new_text.clone())
+ }
+ lsp::CompletionTextEdit::InsertAndReplace(_) => {
+ log::info!("unsupported insert/replace completion");
+ return None;
+ }
+ };
+
+ let clipped_start = this.clip_point_utf16(old_range.start, Bias::Left);
+ let clipped_end = this.clip_point_utf16(old_range.end, Bias::Left);
+ if clipped_start == old_range.start && clipped_end == old_range.end {
+ Some(Completion {
+ old_range: this.anchor_before(old_range.start)
+ ..this.anchor_after(old_range.end),
+ new_text,
+ label: language
+ .as_ref()
+ .and_then(|l| l.label_for_completion(&lsp_completion))
+ .unwrap_or_else(|| CompletionLabel::plain(&lsp_completion)),
+ lsp_completion,
+ })
+ } else {
+ None
+ }
+ })
+ .collect())
})
-
})
} else if let Some(project_id) = self.remote_id() {
let rpc = self.client.clone();
- cx.foreground().spawn(async move {
- let response = rpc
- .request(proto::GetCompletions {
- project_id,
- buffer_id,
- position: Some(language::proto::serialize_anchor(&anchor)),
+ let message = proto::GetCompletions {
+ project_id,
+ buffer_id,
+ position: Some(language::proto::serialize_anchor(&anchor)),
+ version: (&source_buffer.version()).into(),
+ };
+ cx.spawn_weak(|_, mut cx| async move {
+ let response = rpc.request(message).await?;
+
+ source_buffer_handle
+ .update(&mut cx, |buffer, _| {
+ buffer.wait_for_version(response.version.into())
})
- .await?;
+ .await;
+
response
.completions
.into_iter()
@@ -1550,7 +1578,7 @@ impl Project {
})
} else if let Some(project_id) = self.remote_id() {
let rpc = self.client.clone();
- cx.foreground().spawn(async move {
+ cx.spawn_weak(|_, mut cx| async move {
let response = rpc
.request(proto::GetCodeActions {
project_id,
@@ -1559,6 +1587,13 @@ impl Project {
end: Some(language::proto::serialize_anchor(&range.end)),
})
.await?;
+
+ buffer_handle
+ .update(&mut cx, |buffer, _| {
+ buffer.wait_for_version(response.version.into())
+ })
+ .await;
+
response
.actions
.into_iter()
@@ -2124,9 +2159,9 @@ impl Project {
this.update(&mut cx, |this, cx| {
let worktree_id = WorktreeId::from_proto(envelope.payload.worktree_id);
if let Some(worktree) = this.worktree_for_id(worktree_id, cx) {
- worktree.update(cx, |worktree, cx| {
+ worktree.update(cx, |worktree, _| {
let worktree = worktree.as_remote_mut().unwrap();
- worktree.update_from_remote(envelope, cx)
+ worktree.update_from_remote(envelope)
})?;
}
Ok(())
@@ -2188,15 +2223,26 @@ impl Project {
) -> Result<()> {
this.update(&mut cx, |this, cx| {
let payload = envelope.payload.clone();
- let buffer_id = payload.buffer_id as usize;
+ let buffer_id = payload.buffer_id;
let ops = payload
.operations
.into_iter()
.map(|op| language::proto::deserialize_operation(op))
.collect::<Result<Vec<_>, _>>()?;
- if let Some(buffer) = this.open_buffers.get_mut(&buffer_id) {
- if let Some(buffer) = buffer.upgrade(cx) {
- buffer.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))?;
+ let is_remote = this.is_remote();
+ match this.open_buffers.entry(buffer_id) {
+ hash_map::Entry::Occupied(mut e) => match e.get_mut() {
+ OpenBuffer::Loaded(buffer) => {
+ if let Some(buffer) = buffer.upgrade(cx) {
+ buffer.update(cx, |buffer, cx| buffer.apply_ops(ops, cx))?;
+ }
+ }
+ OpenBuffer::Loading(operations) => operations.extend_from_slice(&ops),
+ },
+ hash_map::Entry::Vacant(e) => {
+ if is_remote && this.loading_buffers.len() > 0 {
+ e.insert(OpenBuffer::Loading(ops));
+ }
}
}
Ok(())
@@ -2211,7 +2257,7 @@ impl Project {
) -> Result<()> {
this.update(&mut cx, |this, cx| {
let payload = envelope.payload.clone();
- let buffer_id = payload.buffer_id as usize;
+ let buffer_id = payload.buffer_id;
let file = payload.file.ok_or_else(|| anyhow!("invalid file"))?;
let worktree = this
.worktree_for_id(WorktreeId::from_proto(file.worktree_id), cx)
@@ -2237,21 +2283,30 @@ impl Project {
) -> Result<proto::BufferSaved> {
let buffer_id = envelope.payload.buffer_id;
let sender_id = envelope.original_sender_id()?;
- let (project_id, save) = this.update(&mut cx, |this, cx| {
+ let requested_version = envelope.payload.version.try_into()?;
+
+ let (project_id, buffer) = this.update(&mut cx, |this, _| {
let project_id = this.remote_id().ok_or_else(|| anyhow!("not connected"))?;
let buffer = this
.shared_buffers
.get(&sender_id)
.and_then(|shared_buffers| shared_buffers.get(&buffer_id).cloned())
.ok_or_else(|| anyhow!("unknown buffer id {}", buffer_id))?;
- Ok::<_, anyhow::Error>((project_id, buffer.update(cx, |buffer, cx| buffer.save(cx))))
+ Ok::<_, anyhow::Error>((project_id, buffer))
})?;
- let (version, mtime) = save.await?;
+ if !buffer
+ .read_with(&cx, |buffer, _| buffer.version())
+ .observed_all(&requested_version)
+ {
+ Err(anyhow!("save request depends on unreceived edits"))?;
+ }
+
+ let (saved_version, mtime) = buffer.update(&mut cx, |buffer, cx| buffer.save(cx)).await?;
Ok(proto::BufferSaved {
project_id,
buffer_id,
- version: (&version).into(),
+ version: (&saved_version).into(),
mtime: Some(mtime.into()),
})
}
@@ -2301,21 +2356,30 @@ impl Project {
.position
.and_then(language::proto::deserialize_anchor)
.ok_or_else(|| anyhow!("invalid position"))?;
- let completions = this.update(&mut cx, |this, cx| {
- let buffer = this
- .shared_buffers
+ let version = clock::Global::from(envelope.payload.version);
+ let buffer = this.read_with(&cx, |this, _| {
+ this.shared_buffers
.get(&sender_id)
.and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned())
- .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?;
- Ok::<_, anyhow::Error>(this.completions(&buffer, position, cx))
+ .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))
})?;
+ if !buffer
+ .read_with(&cx, |buffer, _| buffer.version())
+ .observed_all(&version)
+ {
+ Err(anyhow!("completion request depends on unreceived edits"))?;
+ }
+ let version = buffer.read_with(&cx, |buffer, _| buffer.version());
+ let completions = this
+ .update(&mut cx, |this, cx| this.completions(&buffer, position, cx))
+ .await?;
Ok(proto::GetCompletionsResponse {
completions: completions
- .await?
.iter()
.map(language::proto::serialize_completion)
.collect(),
+ version: (&version).into(),
})
}
@@ -2370,12 +2434,17 @@ impl Project {
.end
.and_then(language::proto::deserialize_anchor)
.ok_or_else(|| anyhow!("invalid end"))?;
- let code_actions = this.update(&mut cx, |this, cx| {
- let buffer = this
- .shared_buffers
+ let buffer = this.update(&mut cx, |this, _| {
+ this.shared_buffers
.get(&sender_id)
.and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned())
- .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?;
+ .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))
+ })?;
+ let version = buffer.read_with(&cx, |buffer, _| buffer.version());
+ if !version.observed(start.timestamp) || !version.observed(end.timestamp) {
+ Err(anyhow!("code action request references unreceived edits"))?;
+ }
+ let code_actions = this.update(&mut cx, |this, cx| {
Ok::<_, anyhow::Error>(this.code_actions(&buffer, start..end, cx))
})?;
@@ -2385,6 +2454,7 @@ impl Project {
.iter()
.map(language::proto::serialize_code_action)
.collect(),
+ version: (&version).into(),
})
}
@@ -2516,20 +2586,15 @@ impl Project {
push_to_history: bool,
cx: &mut ModelContext<Self>,
) -> Task<Result<ProjectTransaction>> {
- let mut project_transaction = ProjectTransaction::default();
- for (buffer, transaction) in message.buffers.into_iter().zip(message.transactions) {
- let buffer = match self.deserialize_buffer(buffer, cx) {
- Ok(buffer) => buffer,
- Err(error) => return Task::ready(Err(error)),
- };
- let transaction = match language::proto::deserialize_transaction(transaction) {
- Ok(transaction) => transaction,
- Err(error) => return Task::ready(Err(error)),
- };
- project_transaction.0.insert(buffer, transaction);
- }
-
- cx.spawn_weak(|_, mut cx| async move {
+ cx.spawn(|this, mut cx| async move {
+ let mut project_transaction = ProjectTransaction::default();
+ for (buffer, transaction) in message.buffers.into_iter().zip(message.transactions) {
+ let buffer = this
+ .update(&mut cx, |this, cx| this.deserialize_buffer(buffer, cx))
+ .await?;
+ let transaction = language::proto::deserialize_transaction(transaction)?;
+ project_transaction.0.insert(buffer, transaction);
+ }
for (buffer, transaction) in &project_transaction.0 {
buffer
.update(&mut cx, |buffer, _| {
@@ -2573,33 +2638,60 @@ impl Project {
&mut self,
buffer: proto::Buffer,
cx: &mut ModelContext<Self>,
- ) -> Result<ModelHandle<Buffer>> {
- match buffer.variant.ok_or_else(|| anyhow!("missing buffer"))? {
- proto::buffer::Variant::Id(id) => self
- .open_buffers
- .get(&(id as usize))
- .and_then(|buffer| buffer.upgrade(cx))
- .ok_or_else(|| anyhow!("no buffer exists for id {}", id)),
- proto::buffer::Variant::State(mut buffer) => {
- let mut buffer_worktree = None;
- let mut buffer_file = None;
- if let Some(file) = buffer.file.take() {
- let worktree_id = WorktreeId::from_proto(file.worktree_id);
- let worktree = self
- .worktree_for_id(worktree_id, cx)
- .ok_or_else(|| anyhow!("no worktree found for id {}", file.worktree_id))?;
- buffer_file = Some(Box::new(File::from_proto(file, worktree.clone(), cx)?)
- as Box<dyn language::File>);
- buffer_worktree = Some(worktree);
+ ) -> Task<Result<ModelHandle<Buffer>>> {
+ let replica_id = self.replica_id();
+
+ let mut opened_buffer_tx = self.opened_buffer.clone();
+ let mut opened_buffer_rx = self.opened_buffer.subscribe();
+ cx.spawn(|this, mut cx| async move {
+ match buffer.variant.ok_or_else(|| anyhow!("missing buffer"))? {
+ proto::buffer::Variant::Id(id) => {
+ let buffer = loop {
+ let buffer = this.read_with(&cx, |this, cx| {
+ this.open_buffers
+ .get(&id)
+ .and_then(|buffer| buffer.upgrade(cx))
+ });
+ if let Some(buffer) = buffer {
+ break buffer;
+ }
+ opened_buffer_rx
+ .recv()
+ .await
+ .ok_or_else(|| anyhow!("project dropped while waiting for buffer"))?;
+ };
+ Ok(buffer)
}
+ proto::buffer::Variant::State(mut buffer) => {
+ let mut buffer_worktree = None;
+ let mut buffer_file = None;
+ if let Some(file) = buffer.file.take() {
+ this.read_with(&cx, |this, cx| {
+ let worktree_id = WorktreeId::from_proto(file.worktree_id);
+ let worktree =
+ this.worktree_for_id(worktree_id, cx).ok_or_else(|| {
+ anyhow!("no worktree found for id {}", file.worktree_id)
+ })?;
+ buffer_file =
+ Some(Box::new(File::from_proto(file, worktree.clone(), cx)?)
+ as Box<dyn language::File>);
+ buffer_worktree = Some(worktree);
+ Ok::<_, anyhow::Error>(())
+ })?;
+ }
- let buffer = cx.add_model(|cx| {
- Buffer::from_proto(self.replica_id(), buffer, buffer_file, cx).unwrap()
- });
- self.register_buffer(&buffer, buffer_worktree.as_ref(), cx)?;
- Ok(buffer)
+ let buffer = cx.add_model(|cx| {
+ Buffer::from_proto(replica_id, buffer, buffer_file, cx).unwrap()
+ });
+ this.update(&mut cx, |this, cx| {
+ this.register_buffer(&buffer, buffer_worktree.as_ref(), cx)
+ })?;
+
+ let _ = opened_buffer_tx.send(()).await;
+ Ok(buffer)
+ }
}
- }
+ })
}
async fn handle_close_buffer(
@@ -2635,7 +2727,7 @@ impl Project {
this.update(&mut cx, |this, cx| {
let buffer = this
.open_buffers
- .get(&(envelope.payload.buffer_id as usize))
+ .get(&envelope.payload.buffer_id)
.and_then(|buffer| buffer.upgrade(cx));
if let Some(buffer) = buffer {
buffer.update(cx, |buffer, cx| {
@@ -2661,7 +2753,7 @@ impl Project {
this.update(&mut cx, |this, cx| {
let buffer = this
.open_buffers
- .get(&(payload.buffer_id as usize))
+ .get(&payload.buffer_id)
.and_then(|buffer| buffer.upgrade(cx));
if let Some(buffer) = buffer {
buffer.update(cx, |buffer, cx| {
@@ -2719,6 +2811,15 @@ impl WorktreeHandle {
}
}
+impl OpenBuffer {
+ pub fn upgrade(&self, cx: &impl UpgradeModelHandle) -> Option<ModelHandle<Buffer>> {
+ match self {
+ OpenBuffer::Loaded(handle) => handle.upgrade(cx),
+ OpenBuffer::Loading(_) => None,
+ }
+ }
+}
+
struct CandidateSet {
snapshot: Snapshot,
include_ignored: bool,
@@ -2959,7 +3060,7 @@ mod tests {
#[gpui::test]
async fn test_language_server_diagnostics(mut cx: gpui::TestAppContext) {
- let (language_server_config, mut fake_server) = LanguageServerConfig::fake(&cx).await;
+ let (language_server_config, mut fake_servers) = LanguageServerConfig::fake();
let progress_token = language_server_config
.disk_based_diagnostics_progress_token
.clone()
@@ -3022,6 +3123,7 @@ mod tests {
let mut events = subscribe(&project, &mut cx);
+ let mut fake_server = fake_servers.next().await.unwrap();
fake_server.start_progress(&progress_token).await;
assert_eq!(
events.next().await.unwrap(),
@@ -3123,7 +3225,7 @@ mod tests {
#[gpui::test]
async fn test_definition(mut cx: gpui::TestAppContext) {
- let (language_server_config, mut fake_server) = LanguageServerConfig::fake(&cx).await;
+ let (language_server_config, mut fake_servers) = LanguageServerConfig::fake();
let mut languages = LanguageRegistry::new();
languages.add(Arc::new(Language::new(
@@ -3178,6 +3280,7 @@ mod tests {
.await
.unwrap();
+ let mut fake_server = fake_servers.next().await.unwrap();
fake_server.handle_request::<lsp::request::GotoDefinition, _>(move |params| {
let params = params.text_document_position_params;
assert_eq!(
@@ -3371,7 +3474,7 @@ mod tests {
.await;
// Create a remote copy of this worktree.
- let initial_snapshot = tree.read_with(&cx, |tree, _| tree.snapshot());
+ let initial_snapshot = tree.read_with(&cx, |tree, _| tree.as_local().unwrap().snapshot());
let (remote, load_task) = cx.update(|cx| {
Worktree::remote(
1,
@@ -3447,10 +3550,13 @@ mod tests {
// Update the remote worktree. Check that it becomes consistent with the
// local worktree.
remote.update(&mut cx, |remote, cx| {
- let update_message =
- tree.read(cx)
- .snapshot()
- .build_update(&initial_snapshot, 1, 1, true);
+ let update_message = tree.read(cx).as_local().unwrap().snapshot().build_update(
+ &initial_snapshot,
+ 1,
+ 1,
+ 0,
+ true,
+ );
remote
.as_remote_mut()
.unwrap()
@@ -7,8 +7,11 @@ use ::ignore::gitignore::{Gitignore, GitignoreBuilder};
use anyhow::{anyhow, Result};
use client::{proto, Client, TypedEnvelope};
use clock::ReplicaId;
-use collections::HashMap;
-use futures::{Stream, StreamExt};
+use collections::{HashMap, VecDeque};
+use futures::{
+ channel::mpsc::{self, UnboundedSender},
+ Stream, StreamExt,
+};
use fuzzy::CharBag;
use gpui::{
executor, AppContext, AsyncAppContext, Entity, ModelContext, ModelHandle, MutableAppContext,
@@ -18,6 +21,7 @@ use language::{Buffer, DiagnosticEntry, Operation, PointUtf16, Rope};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use postage::{
+ oneshot,
prelude::{Sink as _, Stream as _},
watch,
};
@@ -75,11 +79,13 @@ pub struct RemoteWorktree {
project_id: u64,
snapshot_rx: watch::Receiver<Snapshot>,
client: Arc<Client>,
- updates_tx: postage::mpsc::Sender<proto::UpdateWorktree>,
+ updates_tx: UnboundedSender<proto::UpdateWorktree>,
replica_id: ReplicaId,
queued_operations: Vec<(u64, Operation)>,
diagnostic_summaries: TreeMap<PathKey, DiagnosticSummary>,
weak: bool,
+ next_update_id: u64,
+ pending_updates: VecDeque<proto::UpdateWorktree>,
}
#[derive(Clone)]
@@ -208,7 +214,7 @@ impl Worktree {
entries_by_id: Default::default(),
};
- let (updates_tx, mut updates_rx) = postage::mpsc::channel(64);
+ let (updates_tx, mut updates_rx) = mpsc::unbounded();
let (mut snapshot_tx, snapshot_rx) = watch::channel_with(snapshot.clone());
let worktree_handle = cx.add_model(|_: &mut ModelContext<Worktree>| {
Worktree::Remote(RemoteWorktree {
@@ -233,6 +239,8 @@ impl Worktree {
}),
),
weak,
+ next_update_id: worktree.next_update_id,
+ pending_updates: Default::default(),
})
});
@@ -276,7 +284,7 @@ impl Worktree {
cx.background()
.spawn(async move {
- while let Some(update) = updates_rx.recv().await {
+ while let Some(update) = updates_rx.next().await {
let mut snapshot = snapshot_tx.borrow().clone();
if let Err(error) = snapshot.apply_remote_update(update) {
log::error!("error applying worktree update: {}", error);
@@ -450,7 +458,7 @@ impl LocalWorktree {
weak: bool,
fs: Arc<dyn Fs>,
cx: &mut AsyncAppContext,
- ) -> Result<(ModelHandle<Worktree>, Sender<ScanState>)> {
+ ) -> Result<(ModelHandle<Worktree>, UnboundedSender<ScanState>)> {
let abs_path = path.into();
let path: Arc<Path> = Arc::from(Path::new(""));
let next_entry_id = AtomicUsize::new(0);
@@ -470,7 +478,7 @@ impl LocalWorktree {
}
}
- let (scan_states_tx, scan_states_rx) = smol::channel::unbounded();
+ let (scan_states_tx, mut scan_states_rx) = mpsc::unbounded();
let (mut last_scan_state_tx, last_scan_state_rx) = watch::channel_with(ScanState::Scanning);
let tree = cx.add_model(move |cx: &mut ModelContext<Worktree>| {
let mut snapshot = LocalSnapshot {
@@ -515,7 +523,7 @@ impl LocalWorktree {
};
cx.spawn_weak(|this, mut cx| async move {
- while let Ok(scan_state) = scan_states_rx.recv().await {
+ while let Some(scan_state) = scan_states_rx.next().await {
if let Some(handle) = this.upgrade(&cx) {
let to_send = handle.update(&mut cx, |this, cx| {
last_scan_state_tx.blocking_send(scan_state).ok();
@@ -761,16 +769,41 @@ impl LocalWorktree {
let worktree_id = cx.model_id() as u64;
let (snapshots_to_send_tx, snapshots_to_send_rx) =
smol::channel::unbounded::<LocalSnapshot>();
+ let (mut share_tx, mut share_rx) = oneshot::channel();
let maintain_remote_snapshot = cx.background().spawn({
let rpc = rpc.clone();
let snapshot = snapshot.clone();
+ let diagnostic_summaries = self.diagnostic_summaries.clone();
+ let weak = self.weak;
async move {
+ if let Err(error) = rpc
+ .request(proto::ShareWorktree {
+ project_id,
+ worktree: Some(snapshot.to_proto(&diagnostic_summaries, weak)),
+ })
+ .await
+ {
+ let _ = share_tx.try_send(Err(error));
+ return;
+ } else {
+ let _ = share_tx.try_send(Ok(()));
+ }
+
+ let mut update_id = 0;
let mut prev_snapshot = snapshot;
while let Ok(snapshot) = snapshots_to_send_rx.recv().await {
- let message =
- snapshot.build_update(&prev_snapshot, project_id, worktree_id, false);
- match rpc.send(message) {
- Ok(()) => prev_snapshot = snapshot,
+ let message = snapshot.build_update(
+ &prev_snapshot,
+ project_id,
+ worktree_id,
+ update_id,
+ false,
+ );
+ match rpc.request(message).await {
+ Ok(_) => {
+ prev_snapshot = snapshot;
+ update_id += 1;
+ }
Err(err) => log::error!("error sending snapshot diff {}", err),
}
}
@@ -782,18 +815,11 @@ impl LocalWorktree {
_maintain_remote_snapshot: Some(maintain_remote_snapshot),
});
- let diagnostic_summaries = self.diagnostic_summaries.clone();
- let weak = self.weak;
- let share_message = cx.background().spawn(async move {
- proto::ShareWorktree {
- project_id,
- worktree: Some(snapshot.to_proto(&diagnostic_summaries, weak)),
- }
- });
-
cx.foreground().spawn(async move {
- rpc.request(share_message.await).await?;
- Ok(())
+ match share_rx.next().await {
+ Some(result) => result,
+ None => Err(anyhow!("unshared before sharing completed")),
+ }
})
}
@@ -814,19 +840,39 @@ impl RemoteWorktree {
pub fn update_from_remote(
&mut self,
envelope: TypedEnvelope<proto::UpdateWorktree>,
- cx: &mut ModelContext<Worktree>,
) -> Result<()> {
- let mut tx = self.updates_tx.clone();
- let payload = envelope.payload.clone();
- cx.foreground()
- .spawn(async move {
- tx.send(payload).await.expect("receiver runs to completion");
- })
- .detach();
+ let update = envelope.payload;
+ if update.id > self.next_update_id {
+ let ix = match self
+ .pending_updates
+ .binary_search_by_key(&update.id, |pending| pending.id)
+ {
+ Ok(ix) | Err(ix) => ix,
+ };
+ self.pending_updates.insert(ix, update);
+ } else {
+ let tx = self.updates_tx.clone();
+ self.next_update_id += 1;
+ tx.unbounded_send(update)
+ .expect("consumer runs to completion");
+ while let Some(update) = self.pending_updates.front() {
+ if update.id == self.next_update_id {
+ self.next_update_id += 1;
+ tx.unbounded_send(self.pending_updates.pop_front().unwrap())
+ .expect("consumer runs to completion");
+ } else {
+ break;
+ }
+ }
+ }
Ok(())
}
+ pub fn has_pending_updates(&self) -> bool {
+ !self.pending_updates.is_empty()
+ }
+
pub fn update_diagnostic_summary(
&mut self,
path: Arc<Path>,
@@ -849,94 +895,6 @@ impl Snapshot {
self.id
}
- pub(crate) fn to_proto(
- &self,
- diagnostic_summaries: &TreeMap<PathKey, DiagnosticSummary>,
- weak: bool,
- ) -> proto::Worktree {
- let root_name = self.root_name.clone();
- proto::Worktree {
- id: self.id.0 as u64,
- root_name,
- entries: self
- .entries_by_path
- .iter()
- .filter(|e| !e.is_ignored)
- .map(Into::into)
- .collect(),
- diagnostic_summaries: diagnostic_summaries
- .iter()
- .map(|(path, summary)| summary.to_proto(path.0.clone()))
- .collect(),
- weak,
- }
- }
-
- pub(crate) fn build_update(
- &self,
- other: &Self,
- project_id: u64,
- worktree_id: u64,
- include_ignored: bool,
- ) -> proto::UpdateWorktree {
- let mut updated_entries = Vec::new();
- let mut removed_entries = Vec::new();
- let mut self_entries = self
- .entries_by_id
- .cursor::<()>()
- .filter(|e| include_ignored || !e.is_ignored)
- .peekable();
- let mut other_entries = other
- .entries_by_id
- .cursor::<()>()
- .filter(|e| include_ignored || !e.is_ignored)
- .peekable();
- loop {
- match (self_entries.peek(), other_entries.peek()) {
- (Some(self_entry), Some(other_entry)) => {
- match Ord::cmp(&self_entry.id, &other_entry.id) {
- Ordering::Less => {
- let entry = self.entry_for_id(self_entry.id).unwrap().into();
- updated_entries.push(entry);
- self_entries.next();
- }
- Ordering::Equal => {
- if self_entry.scan_id != other_entry.scan_id {
- let entry = self.entry_for_id(self_entry.id).unwrap().into();
- updated_entries.push(entry);
- }
-
- self_entries.next();
- other_entries.next();
- }
- Ordering::Greater => {
- removed_entries.push(other_entry.id as u64);
- other_entries.next();
- }
- }
- }
- (Some(self_entry), None) => {
- let entry = self.entry_for_id(self_entry.id).unwrap().into();
- updated_entries.push(entry);
- self_entries.next();
- }
- (None, Some(other_entry)) => {
- removed_entries.push(other_entry.id as u64);
- other_entries.next();
- }
- (None, None) => break,
- }
- }
-
- proto::UpdateWorktree {
- project_id,
- worktree_id,
- root_name: self.root_name().to_string(),
- updated_entries,
- removed_entries,
- }
- }
-
pub(crate) fn apply_remote_update(&mut self, update: proto::UpdateWorktree) -> Result<()> {
let mut entries_by_path_edits = Vec::new();
let mut entries_by_id_edits = Vec::new();
@@ -1077,6 +1035,97 @@ impl Snapshot {
}
impl LocalSnapshot {
+ pub(crate) fn to_proto(
+ &self,
+ diagnostic_summaries: &TreeMap<PathKey, DiagnosticSummary>,
+ weak: bool,
+ ) -> proto::Worktree {
+ let root_name = self.root_name.clone();
+ proto::Worktree {
+ id: self.id.0 as u64,
+ root_name,
+ entries: self
+ .entries_by_path
+ .iter()
+ .filter(|e| !e.is_ignored)
+ .map(Into::into)
+ .collect(),
+ diagnostic_summaries: diagnostic_summaries
+ .iter()
+ .map(|(path, summary)| summary.to_proto(path.0.clone()))
+ .collect(),
+ weak,
+ next_update_id: 0,
+ }
+ }
+
+ pub(crate) fn build_update(
+ &self,
+ other: &Self,
+ project_id: u64,
+ worktree_id: u64,
+ update_id: u64,
+ include_ignored: bool,
+ ) -> proto::UpdateWorktree {
+ let mut updated_entries = Vec::new();
+ let mut removed_entries = Vec::new();
+ let mut self_entries = self
+ .entries_by_id
+ .cursor::<()>()
+ .filter(|e| include_ignored || !e.is_ignored)
+ .peekable();
+ let mut other_entries = other
+ .entries_by_id
+ .cursor::<()>()
+ .filter(|e| include_ignored || !e.is_ignored)
+ .peekable();
+ loop {
+ match (self_entries.peek(), other_entries.peek()) {
+ (Some(self_entry), Some(other_entry)) => {
+ match Ord::cmp(&self_entry.id, &other_entry.id) {
+ Ordering::Less => {
+ let entry = self.entry_for_id(self_entry.id).unwrap().into();
+ updated_entries.push(entry);
+ self_entries.next();
+ }
+ Ordering::Equal => {
+ if self_entry.scan_id != other_entry.scan_id {
+ let entry = self.entry_for_id(self_entry.id).unwrap().into();
+ updated_entries.push(entry);
+ }
+
+ self_entries.next();
+ other_entries.next();
+ }
+ Ordering::Greater => {
+ removed_entries.push(other_entry.id as u64);
+ other_entries.next();
+ }
+ }
+ }
+ (Some(self_entry), None) => {
+ let entry = self.entry_for_id(self_entry.id).unwrap().into();
+ updated_entries.push(entry);
+ self_entries.next();
+ }
+ (None, Some(other_entry)) => {
+ removed_entries.push(other_entry.id as u64);
+ other_entries.next();
+ }
+ (None, None) => break,
+ }
+ }
+
+ proto::UpdateWorktree {
+ id: update_id as u64,
+ project_id,
+ worktree_id,
+ root_name: self.root_name().to_string(),
+ updated_entries,
+ removed_entries,
+ }
+ }
+
fn insert_entry(&mut self, mut entry: Entry, fs: &dyn Fs) -> Entry {
if !entry.is_dir() && entry.path.file_name() == Some(&GITIGNORE) {
let abs_path = self.abs_path.join(&entry.path);
@@ -1283,13 +1332,29 @@ impl fmt::Debug for LocalWorktree {
impl fmt::Debug for Snapshot {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- for entry in self.entries_by_path.cursor::<()>() {
- for _ in entry.path.ancestors().skip(1) {
- write!(f, " ")?;
+ struct EntriesById<'a>(&'a SumTree<PathEntry>);
+ struct EntriesByPath<'a>(&'a SumTree<Entry>);
+
+ impl<'a> fmt::Debug for EntriesByPath<'a> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_map()
+ .entries(self.0.iter().map(|entry| (&entry.path, entry.id)))
+ .finish()
}
- writeln!(f, "{:?} (inode: {})", entry.path, entry.inode)?;
}
- Ok(())
+
+ impl<'a> fmt::Debug for EntriesById<'a> {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ f.debug_list().entries(self.0.iter()).finish()
+ }
+ }
+
+ f.debug_struct("Snapshot")
+ .field("id", &self.id)
+ .field("root_name", &self.root_name)
+ .field("entries_by_path", &EntriesByPath(&self.entries_by_path))
+ .field("entries_by_id", &EntriesById(&self.entries_by_id))
+ .finish()
}
}
@@ -1322,7 +1387,9 @@ impl language::File for File {
fn full_path(&self, cx: &AppContext) -> PathBuf {
let mut full_path = PathBuf::new();
full_path.push(self.worktree.read(cx).root_name());
- full_path.push(&self.path);
+ if self.path.components().next().is_some() {
+ full_path.push(&self.path);
+ }
full_path
}
@@ -1372,6 +1439,7 @@ impl language::File for File {
.request(proto::SaveBuffer {
project_id,
buffer_id,
+ version: (&version).into(),
})
.await?;
let version = response.version.try_into()?;
@@ -1493,7 +1561,7 @@ impl File {
}
}
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Entry {
pub id: usize,
pub kind: EntryKind,
@@ -1504,7 +1572,7 @@ pub struct Entry {
pub is_ignored: bool,
}
-#[derive(Clone, Debug)]
+#[derive(Clone, Debug, PartialEq, Eq)]
pub enum EntryKind {
PendingDir,
Dir,
@@ -1668,14 +1736,14 @@ impl<'a> sum_tree::Dimension<'a, EntrySummary> for PathKey {
struct BackgroundScanner {
fs: Arc<dyn Fs>,
snapshot: Arc<Mutex<LocalSnapshot>>,
- notify: Sender<ScanState>,
+ notify: UnboundedSender<ScanState>,
executor: Arc<executor::Background>,
}
impl BackgroundScanner {
fn new(
snapshot: Arc<Mutex<LocalSnapshot>>,
- notify: Sender<ScanState>,
+ notify: UnboundedSender<ScanState>,
fs: Arc<dyn Fs>,
executor: Arc<executor::Background>,
) -> Self {
@@ -1696,28 +1764,27 @@ impl BackgroundScanner {
}
async fn run(mut self, events_rx: impl Stream<Item = Vec<fsevent::Event>>) {
- if self.notify.send(ScanState::Scanning).await.is_err() {
+ if self.notify.unbounded_send(ScanState::Scanning).is_err() {
return;
}
if let Err(err) = self.scan_dirs().await {
if self
.notify
- .send(ScanState::Err(Arc::new(err)))
- .await
+ .unbounded_send(ScanState::Err(Arc::new(err)))
.is_err()
{
return;
}
}
- if self.notify.send(ScanState::Idle).await.is_err() {
+ if self.notify.unbounded_send(ScanState::Idle).is_err() {
return;
}
futures::pin_mut!(events_rx);
while let Some(events) = events_rx.next().await {
- if self.notify.send(ScanState::Scanning).await.is_err() {
+ if self.notify.unbounded_send(ScanState::Scanning).is_err() {
break;
}
@@ -1725,7 +1792,7 @@ impl BackgroundScanner {
break;
}
- if self.notify.send(ScanState::Idle).await.is_err() {
+ if self.notify.unbounded_send(ScanState::Idle).is_err() {
break;
}
}
@@ -2391,7 +2458,7 @@ mod tests {
fmt::Write,
time::{SystemTime, UNIX_EPOCH},
};
- use util::test::temp_tree;
+ use util::{post_inc, test::temp_tree};
#[gpui::test]
async fn test_traversal(cx: gpui::TestAppContext) {
@@ -2503,7 +2570,7 @@ mod tests {
}
log::info!("Generated initial tree");
- let (notify_tx, _notify_rx) = smol::channel::unbounded();
+ let (notify_tx, _notify_rx) = mpsc::unbounded();
let fs = Arc::new(RealFs);
let next_entry_id = Arc::new(AtomicUsize::new(0));
let mut initial_snapshot = LocalSnapshot {
@@ -2563,7 +2630,7 @@ mod tests {
smol::block_on(scanner.process_events(events));
scanner.snapshot().check_invariants();
- let (notify_tx, _notify_rx) = smol::channel::unbounded();
+ let (notify_tx, _notify_rx) = mpsc::unbounded();
let mut new_scanner = BackgroundScanner::new(
Arc::new(Mutex::new(initial_snapshot)),
notify_tx,
@@ -2576,6 +2643,7 @@ mod tests {
new_scanner.snapshot().to_vec(true)
);
+ let mut update_id = 0;
for mut prev_snapshot in snapshots {
let include_ignored = rng.gen::<bool>();
if !include_ignored {
@@ -2596,9 +2664,13 @@ mod tests {
prev_snapshot.entries_by_id.edit(entries_by_id_edits, &());
}
- let update = scanner
- .snapshot()
- .build_update(&prev_snapshot, 0, 0, include_ignored);
+ let update = scanner.snapshot().build_update(
+ &prev_snapshot,
+ 0,
+ 0,
+ post_inc(&mut update_id),
+ include_ignored,
+ );
prev_snapshot.apply_remote_update(update).unwrap();
assert_eq!(
prev_snapshot.to_vec(true),
@@ -9,62 +9,63 @@ message Envelope {
Ack ack = 4;
Error error = 5;
Ping ping = 6;
-
- RegisterProject register_project = 7;
- RegisterProjectResponse register_project_response = 8;
- UnregisterProject unregister_project = 9;
- ShareProject share_project = 10;
- UnshareProject unshare_project = 11;
- JoinProject join_project = 12;
- JoinProjectResponse join_project_response = 13;
- LeaveProject leave_project = 14;
- AddProjectCollaborator add_project_collaborator = 15;
- RemoveProjectCollaborator remove_project_collaborator = 16;
- GetDefinition get_definition = 17;
- GetDefinitionResponse get_definition_response = 18;
-
- RegisterWorktree register_worktree = 19;
- UnregisterWorktree unregister_worktree = 20;
- ShareWorktree share_worktree = 21;
- UpdateWorktree update_worktree = 22;
- UpdateDiagnosticSummary update_diagnostic_summary = 23;
- DiskBasedDiagnosticsUpdating disk_based_diagnostics_updating = 24;
- DiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 25;
-
- OpenBuffer open_buffer = 26;
- OpenBufferResponse open_buffer_response = 27;
- CloseBuffer close_buffer = 28;
- UpdateBuffer update_buffer = 29;
- UpdateBufferFile update_buffer_file = 30;
- SaveBuffer save_buffer = 31;
- BufferSaved buffer_saved = 32;
- BufferReloaded buffer_reloaded = 33;
- FormatBuffers format_buffers = 34;
- FormatBuffersResponse format_buffers_response = 35;
- GetCompletions get_completions = 36;
- GetCompletionsResponse get_completions_response = 37;
- ApplyCompletionAdditionalEdits apply_completion_additional_edits = 38;
- ApplyCompletionAdditionalEditsResponse apply_completion_additional_edits_response = 39;
- GetCodeActions get_code_actions = 40;
- GetCodeActionsResponse get_code_actions_response = 41;
- ApplyCodeAction apply_code_action = 42;
- ApplyCodeActionResponse apply_code_action_response = 43;
-
- GetChannels get_channels = 44;
- GetChannelsResponse get_channels_response = 45;
- JoinChannel join_channel = 46;
- JoinChannelResponse join_channel_response = 47;
- LeaveChannel leave_channel = 48;
- SendChannelMessage send_channel_message = 49;
- SendChannelMessageResponse send_channel_message_response = 50;
- ChannelMessageSent channel_message_sent = 51;
- GetChannelMessages get_channel_messages = 52;
- GetChannelMessagesResponse get_channel_messages_response = 53;
-
- UpdateContacts update_contacts = 54;
-
- GetUsers get_users = 55;
- GetUsersResponse get_users_response = 56;
+ Test test = 7;
+
+ RegisterProject register_project = 8;
+ RegisterProjectResponse register_project_response = 9;
+ UnregisterProject unregister_project = 10;
+ ShareProject share_project = 11;
+ UnshareProject unshare_project = 12;
+ JoinProject join_project = 13;
+ JoinProjectResponse join_project_response = 14;
+ LeaveProject leave_project = 15;
+ AddProjectCollaborator add_project_collaborator = 16;
+ RemoveProjectCollaborator remove_project_collaborator = 17;
+ GetDefinition get_definition = 18;
+ GetDefinitionResponse get_definition_response = 19;
+
+ RegisterWorktree register_worktree = 20;
+ UnregisterWorktree unregister_worktree = 21;
+ ShareWorktree share_worktree = 22;
+ UpdateWorktree update_worktree = 23;
+ UpdateDiagnosticSummary update_diagnostic_summary = 24;
+ DiskBasedDiagnosticsUpdating disk_based_diagnostics_updating = 25;
+ DiskBasedDiagnosticsUpdated disk_based_diagnostics_updated = 26;
+
+ OpenBuffer open_buffer = 27;
+ OpenBufferResponse open_buffer_response = 28;
+ CloseBuffer close_buffer = 29;
+ UpdateBuffer update_buffer = 30;
+ UpdateBufferFile update_buffer_file = 31;
+ SaveBuffer save_buffer = 32;
+ BufferSaved buffer_saved = 33;
+ BufferReloaded buffer_reloaded = 34;
+ FormatBuffers format_buffers = 35;
+ FormatBuffersResponse format_buffers_response = 36;
+ GetCompletions get_completions = 37;
+ GetCompletionsResponse get_completions_response = 38;
+ ApplyCompletionAdditionalEdits apply_completion_additional_edits = 39;
+ ApplyCompletionAdditionalEditsResponse apply_completion_additional_edits_response = 40;
+ GetCodeActions get_code_actions = 41;
+ GetCodeActionsResponse get_code_actions_response = 42;
+ ApplyCodeAction apply_code_action = 43;
+ ApplyCodeActionResponse apply_code_action_response = 44;
+
+ GetChannels get_channels = 45;
+ GetChannelsResponse get_channels_response = 46;
+ JoinChannel join_channel = 47;
+ JoinChannelResponse join_channel_response = 48;
+ LeaveChannel leave_channel = 49;
+ SendChannelMessage send_channel_message = 50;
+ SendChannelMessageResponse send_channel_message_response = 51;
+ ChannelMessageSent channel_message_sent = 52;
+ GetChannelMessages get_channel_messages = 53;
+ GetChannelMessagesResponse get_channel_messages_response = 54;
+
+ UpdateContacts update_contacts = 55;
+
+ GetUsers get_users = 56;
+ GetUsersResponse get_users_response = 57;
}
}
@@ -78,6 +79,10 @@ message Error {
string message = 1;
}
+message Test {
+ uint64 id = 1;
+}
+
message RegisterProject {}
message RegisterProjectResponse {
@@ -128,11 +133,12 @@ message ShareWorktree {
}
message UpdateWorktree {
- uint64 project_id = 1;
- uint64 worktree_id = 2;
- string root_name = 3;
- repeated Entry updated_entries = 4;
- repeated uint64 removed_entries = 5;
+ uint64 id = 1;
+ uint64 project_id = 2;
+ uint64 worktree_id = 3;
+ string root_name = 4;
+ repeated Entry updated_entries = 5;
+ repeated uint64 removed_entries = 6;
}
message AddProjectCollaborator {
@@ -191,6 +197,7 @@ message UpdateBufferFile {
message SaveBuffer {
uint64 project_id = 1;
uint64 buffer_id = 2;
+ repeated VectorClockEntry version = 3;
}
message BufferSaved {
@@ -220,10 +227,12 @@ message GetCompletions {
uint64 project_id = 1;
uint64 buffer_id = 2;
Anchor position = 3;
+ repeated VectorClockEntry version = 4;
}
message GetCompletionsResponse {
repeated Completion completions = 1;
+ repeated VectorClockEntry version = 2;
}
message ApplyCompletionAdditionalEdits {
@@ -252,6 +261,7 @@ message GetCodeActions {
message GetCodeActionsResponse {
repeated CodeAction actions = 1;
+ repeated VectorClockEntry version = 2;
}
message ApplyCodeAction {
@@ -386,6 +396,7 @@ message Worktree {
repeated Entry entries = 3;
repeated DiagnosticSummary diagnostic_summaries = 4;
bool weak = 5;
+ uint64 next_update_id = 6;
}
message File {
@@ -265,7 +265,7 @@ impl Peer {
.await
.ok_or_else(|| anyhow!("connection was closed"))?;
if let Some(proto::envelope::Payload::Error(error)) = &response.payload {
- Err(anyhow!("request failed").context(error.message.clone()))
+ Err(anyhow!("RPC request failed - {}", error.message))
} else {
T::Response::from_envelope(response)
.ok_or_else(|| anyhow!("received response of the wrong type"))
@@ -402,40 +402,18 @@ mod tests {
assert_eq!(
client1
- .request(
- client1_conn_id,
- proto::OpenBuffer {
- project_id: 0,
- worktree_id: 1,
- path: "path/one".to_string(),
- },
- )
+ .request(client1_conn_id, proto::Test { id: 1 },)
.await
.unwrap(),
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- variant: Some(proto::buffer::Variant::Id(0))
- }),
- }
+ proto::Test { id: 1 }
);
assert_eq!(
client2
- .request(
- client2_conn_id,
- proto::OpenBuffer {
- project_id: 0,
- worktree_id: 2,
- path: "path/two".to_string(),
- },
- )
+ .request(client2_conn_id, proto::Test { id: 2 })
.await
.unwrap(),
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- variant: Some(proto::buffer::Variant::Id(1))
- })
- }
+ proto::Test { id: 2 }
);
client1.disconnect(client1_conn_id);
@@ -450,34 +428,9 @@ mod tests {
if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Ping>>() {
let receipt = envelope.receipt();
peer.respond(receipt, proto::Ack {})?
- } else if let Some(envelope) =
- envelope.downcast_ref::<TypedEnvelope<proto::OpenBuffer>>()
+ } else if let Some(envelope) = envelope.downcast_ref::<TypedEnvelope<proto::Test>>()
{
- let message = &envelope.payload;
- let receipt = envelope.receipt();
- let response = match message.path.as_str() {
- "path/one" => {
- assert_eq!(message.worktree_id, 1);
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- variant: Some(proto::buffer::Variant::Id(0)),
- }),
- }
- }
- "path/two" => {
- assert_eq!(message.worktree_id, 2);
- proto::OpenBufferResponse {
- buffer: Some(proto::Buffer {
- variant: Some(proto::buffer::Variant::Id(1)),
- }),
- }
- }
- _ => {
- panic!("unexpected path {}", message.path);
- }
- };
-
- peer.respond(receipt, response)?
+ peer.respond(envelope.receipt(), envelope.payload.clone())?
} else {
panic!("unknown message type");
}
@@ -13,6 +13,7 @@ include!(concat!(env!("OUT_DIR"), "/zed.messages.rs"));
pub trait EnvelopedMessage: Clone + Sized + Send + Sync + 'static {
const NAME: &'static str;
+ const PRIORITY: MessagePriority;
fn into_envelope(
self,
id: u32,
@@ -35,6 +36,12 @@ pub trait AnyTypedEnvelope: 'static + Send + Sync {
fn payload_type_name(&self) -> &'static str;
fn as_any(&self) -> &dyn Any;
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync>;
+ fn is_background(&self) -> bool;
+}
+
+pub enum MessagePriority {
+ Foreground,
+ Background,
}
impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
@@ -53,10 +60,14 @@ impl<T: EnvelopedMessage> AnyTypedEnvelope for TypedEnvelope<T> {
fn into_any(self: Box<Self>) -> Box<dyn Any + Send + Sync> {
self
}
+
+ fn is_background(&self) -> bool {
+ matches!(T::PRIORITY, MessagePriority::Background)
+ }
}
macro_rules! messages {
- ($($name:ident),* $(,)?) => {
+ ($(($name:ident, $priority:ident)),* $(,)?) => {
pub fn build_typed_envelope(sender_id: ConnectionId, envelope: Envelope) -> Option<Box<dyn AnyTypedEnvelope>> {
match envelope.payload {
$(Some(envelope::Payload::$name(payload)) => {
@@ -74,6 +85,7 @@ macro_rules! messages {
$(
impl EnvelopedMessage for $name {
const NAME: &'static str = std::stringify!($name);
+ const PRIORITY: MessagePriority = MessagePriority::$priority;
fn into_envelope(
self,
@@ -120,59 +132,60 @@ macro_rules! entity_messages {
}
messages!(
- Ack,
- AddProjectCollaborator,
- ApplyCodeAction,
- ApplyCodeActionResponse,
- ApplyCompletionAdditionalEdits,
- ApplyCompletionAdditionalEditsResponse,
- BufferReloaded,
- BufferSaved,
- ChannelMessageSent,
- CloseBuffer,
- DiskBasedDiagnosticsUpdated,
- DiskBasedDiagnosticsUpdating,
- Error,
- FormatBuffers,
- FormatBuffersResponse,
- GetChannelMessages,
- GetChannelMessagesResponse,
- GetChannels,
- GetChannelsResponse,
- GetCodeActions,
- GetCodeActionsResponse,
- GetCompletions,
- GetCompletionsResponse,
- GetDefinition,
- GetDefinitionResponse,
- GetUsers,
- GetUsersResponse,
- JoinChannel,
- JoinChannelResponse,
- JoinProject,
- JoinProjectResponse,
- LeaveChannel,
- LeaveProject,
- OpenBuffer,
- OpenBufferResponse,
- RegisterProjectResponse,
- Ping,
- RegisterProject,
- RegisterWorktree,
- RemoveProjectCollaborator,
- SaveBuffer,
- SendChannelMessage,
- SendChannelMessageResponse,
- ShareProject,
- ShareWorktree,
- UnregisterProject,
- UnregisterWorktree,
- UnshareProject,
- UpdateBuffer,
- UpdateBufferFile,
- UpdateContacts,
- UpdateDiagnosticSummary,
- UpdateWorktree,
+ (Ack, Foreground),
+ (AddProjectCollaborator, Foreground),
+ (ApplyCodeAction, Foreground),
+ (ApplyCodeActionResponse, Foreground),
+ (ApplyCompletionAdditionalEdits, Foreground),
+ (ApplyCompletionAdditionalEditsResponse, Foreground),
+ (BufferReloaded, Foreground),
+ (BufferSaved, Foreground),
+ (ChannelMessageSent, Foreground),
+ (CloseBuffer, Foreground),
+ (DiskBasedDiagnosticsUpdated, Background),
+ (DiskBasedDiagnosticsUpdating, Background),
+ (Error, Foreground),
+ (FormatBuffers, Foreground),
+ (FormatBuffersResponse, Foreground),
+ (GetChannelMessages, Foreground),
+ (GetChannelMessagesResponse, Foreground),
+ (GetChannels, Foreground),
+ (GetChannelsResponse, Foreground),
+ (GetCodeActions, Background),
+ (GetCodeActionsResponse, Foreground),
+ (GetCompletions, Background),
+ (GetCompletionsResponse, Foreground),
+ (GetDefinition, Foreground),
+ (GetDefinitionResponse, Foreground),
+ (GetUsers, Foreground),
+ (GetUsersResponse, Foreground),
+ (JoinChannel, Foreground),
+ (JoinChannelResponse, Foreground),
+ (JoinProject, Foreground),
+ (JoinProjectResponse, Foreground),
+ (LeaveChannel, Foreground),
+ (LeaveProject, Foreground),
+ (OpenBuffer, Foreground),
+ (OpenBufferResponse, Foreground),
+ (RegisterProjectResponse, Foreground),
+ (Ping, Foreground),
+ (RegisterProject, Foreground),
+ (RegisterWorktree, Foreground),
+ (RemoveProjectCollaborator, Foreground),
+ (SaveBuffer, Foreground),
+ (SendChannelMessage, Foreground),
+ (SendChannelMessageResponse, Foreground),
+ (ShareProject, Foreground),
+ (ShareWorktree, Foreground),
+ (Test, Foreground),
+ (UnregisterProject, Foreground),
+ (UnregisterWorktree, Foreground),
+ (UnshareProject, Foreground),
+ (UpdateBuffer, Foreground),
+ (UpdateBufferFile, Foreground),
+ (UpdateContacts, Foreground),
+ (UpdateDiagnosticSummary, Foreground),
+ (UpdateWorktree, Foreground),
);
request_messages!(
@@ -198,7 +211,9 @@ request_messages!(
(SendChannelMessage, SendChannelMessageResponse),
(ShareProject, Ack),
(ShareWorktree, Ack),
+ (Test, Test),
(UpdateBuffer, Ack),
+ (UpdateWorktree, Ack),
);
entity_messages!(
@@ -256,12 +271,19 @@ where
{
/// Write a given protobuf message to the stream.
pub async fn write_message(&mut self, message: &Envelope) -> Result<(), WebSocketError> {
+ #[cfg(any(test, feature = "test-support"))]
+ const COMPRESSION_LEVEL: i32 = -7;
+
+ #[cfg(not(any(test, feature = "test-support")))]
+ const COMPRESSION_LEVEL: i32 = 4;
+
self.encoding_buffer.resize(message.encoded_len(), 0);
self.encoding_buffer.clear();
message
.encode(&mut self.encoding_buffer)
.map_err(|err| io::Error::from(err))?;
- let buffer = zstd::stream::encode_all(self.encoding_buffer.as_slice(), 4).unwrap();
+ let buffer =
+ zstd::stream::encode_all(self.encoding_buffer.as_slice(), COMPRESSION_LEVEL).unwrap();
self.stream.send(WebSocketMessage::Binary(buffer)).await?;
Ok(())
}
@@ -15,7 +15,6 @@ required-features = ["seed-support"]
[dependencies]
collections = { path = "../collections" }
rpc = { path = "../rpc" }
-
anyhow = "1.0.40"
async-std = { version = "1.8.0", features = ["attributes"] }
async-trait = "0.1.50"
@@ -57,12 +56,12 @@ features = ["runtime-async-std-rustls", "postgres", "time", "uuid"]
[dev-dependencies]
collections = { path = "../collections", features = ["test-support"] }
-gpui = { path = "../gpui" }
+gpui = { path = "../gpui", features = ["test-support"] }
+rpc = { path = "../rpc", features = ["test-support"] }
zed = { path = "../zed", features = ["test-support"] }
ctor = "0.1"
env_logger = "0.8"
util = { path = "../util" }
-
lazy_static = "1.4"
serde_json = { version = "1.0.64", features = ["preserve_order"] }
@@ -111,7 +111,7 @@ async fn create_access_token(request: Request) -> tide::Result {
.get_user_by_github_login(request.param("github_login")?)
.await?
.ok_or_else(|| surf::Error::from_str(StatusCode::NotFound, "user not found"))?;
- let access_token = auth::create_access_token(request.db(), user.id).await?;
+ let access_token = auth::create_access_token(request.db().as_ref(), user.id).await?;
#[derive(Deserialize)]
struct QueryParams {
@@ -234,7 +234,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
let mut user_id = user.id;
if let Some(impersonated_login) = app_sign_in_params.impersonate {
log::info!("attempting to impersonate user @{}", impersonated_login);
- if let Some(user) = request.db().get_users_by_ids([user_id]).await?.first() {
+ if let Some(user) = request.db().get_users_by_ids(vec![user_id]).await?.first() {
if user.admin {
user_id = request.db().create_user(&impersonated_login, false).await?;
log::info!("impersonating user {}", user_id.0);
@@ -244,7 +244,7 @@ async fn get_auth_callback(mut request: Request) -> tide::Result {
}
}
- let access_token = create_access_token(request.db(), user_id).await?;
+ let access_token = create_access_token(request.db().as_ref(), user_id).await?;
let encrypted_access_token = encrypt_access_token(
&access_token,
app_sign_in_params.native_app_public_key.clone(),
@@ -267,7 +267,7 @@ async fn post_sign_out(mut request: Request) -> tide::Result {
const MAX_ACCESS_TOKENS_TO_STORE: usize = 8;
-pub async fn create_access_token(db: &db::Db, user_id: UserId) -> tide::Result<String> {
+pub async fn create_access_token(db: &dyn db::Db, user_id: UserId) -> tide::Result<String> {
let access_token = zed_auth::random_token();
let access_token_hash =
hash_access_token(&access_token).context("failed to hash access token")?;
@@ -1,11 +1,12 @@
use anyhow::Context;
+use anyhow::Result;
+pub use async_sqlx_session::PostgresSessionStore as SessionStore;
use async_std::task::{block_on, yield_now};
+use async_trait::async_trait;
use serde::Serialize;
-use sqlx::{types::Uuid, FromRow, Result};
-use time::OffsetDateTime;
-
-pub use async_sqlx_session::PostgresSessionStore as SessionStore;
pub use sqlx::postgres::PgPoolOptions as DbOptions;
+use sqlx::{types::Uuid, FromRow};
+use time::OffsetDateTime;
macro_rules! test_support {
($self:ident, { $($token:tt)* }) => {{
@@ -21,13 +22,77 @@ macro_rules! test_support {
}};
}
-#[derive(Clone)]
-pub struct Db {
+#[async_trait]
+pub trait Db: Send + Sync {
+ async fn create_signup(
+ &self,
+ github_login: &str,
+ email_address: &str,
+ about: &str,
+ wants_releases: bool,
+ wants_updates: bool,
+ wants_community: bool,
+ ) -> Result<SignupId>;
+ async fn get_all_signups(&self) -> Result<Vec<Signup>>;
+ async fn destroy_signup(&self, id: SignupId) -> Result<()>;
+ async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId>;
+ async fn get_all_users(&self) -> Result<Vec<User>>;
+ async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>>;
+ async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>>;
+ async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>>;
+ async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()>;
+ async fn destroy_user(&self, id: UserId) -> Result<()>;
+ async fn create_access_token_hash(
+ &self,
+ user_id: UserId,
+ access_token_hash: &str,
+ max_access_token_count: usize,
+ ) -> Result<()>;
+ async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>>;
+ async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>>;
+ async fn can_user_access_channel(&self, user_id: UserId, channel_id: ChannelId)
+ -> Result<bool>;
+ #[cfg(any(test, feature = "seed-support"))]
+ async fn add_channel_member(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()>;
+ async fn create_channel_message(
+ &self,
+ channel_id: ChannelId,
+ sender_id: UserId,
+ body: &str,
+ timestamp: OffsetDateTime,
+ nonce: u128,
+ ) -> Result<MessageId>;
+ async fn get_channel_messages(
+ &self,
+ channel_id: ChannelId,
+ count: usize,
+ before_id: Option<MessageId>,
+ ) -> Result<Vec<ChannelMessage>>;
+ #[cfg(test)]
+ async fn teardown(&self, name: &str, url: &str);
+}
+
+pub struct PostgresDb {
pool: sqlx::PgPool,
test_mode: bool,
}
-impl Db {
+impl PostgresDb {
pub async fn new(url: &str, max_connections: u32) -> tide::Result<Self> {
let pool = DbOptions::new()
.max_connections(max_connections)
@@ -39,10 +104,12 @@ impl Db {
test_mode: false,
})
}
+}
+#[async_trait]
+impl Db for PostgresDb {
// signups
-
- pub async fn create_signup(
+ async fn create_signup(
&self,
github_login: &str,
email_address: &str,
@@ -64,7 +131,7 @@ impl Db {
VALUES ($1, $2, $3, $4, $5, $6)
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(github_login)
.bind(email_address)
.bind(about)
@@ -73,31 +140,31 @@ impl Db {
.bind(wants_community)
.fetch_one(&self.pool)
.await
- .map(SignupId)
+ .map(SignupId)?)
})
}
- pub async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+ async fn get_all_signups(&self) -> Result<Vec<Signup>> {
test_support!(self, {
let query = "SELECT * FROM signups ORDER BY github_login ASC";
- sqlx::query_as(query).fetch_all(&self.pool).await
+ Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
})
}
- pub async fn destroy_signup(&self, id: SignupId) -> Result<()> {
+ async fn destroy_signup(&self, id: SignupId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM signups WHERE id = $1";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(id.0)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// users
- pub async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+ async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
test_support!(self, {
let query = "
INSERT INTO users (github_login, admin)
@@ -105,31 +172,28 @@ impl Db {
ON CONFLICT (github_login) DO UPDATE SET github_login = excluded.github_login
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(github_login)
.bind(admin)
.fetch_one(&self.pool)
.await
- .map(UserId)
+ .map(UserId)?)
})
}
- pub async fn get_all_users(&self) -> Result<Vec<User>> {
+ async fn get_all_users(&self) -> Result<Vec<User>> {
test_support!(self, {
let query = "SELECT * FROM users ORDER BY github_login ASC";
- sqlx::query_as(query).fetch_all(&self.pool).await
+ Ok(sqlx::query_as(query).fetch_all(&self.pool).await?)
})
}
- pub async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
- let users = self.get_users_by_ids([id]).await?;
+ async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+ let users = self.get_users_by_ids(vec![id]).await?;
Ok(users.into_iter().next())
}
- pub async fn get_users_by_ids(
- &self,
- ids: impl IntoIterator<Item = UserId>,
- ) -> Result<Vec<User>> {
+ async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
let ids = ids.into_iter().map(|id| id.0).collect::<Vec<_>>();
test_support!(self, {
let query = "
@@ -138,33 +202,36 @@ impl Db {
WHERE users.id = ANY ($1)
";
- sqlx::query_as(query).bind(&ids).fetch_all(&self.pool).await
+ Ok(sqlx::query_as(query)
+ .bind(&ids)
+ .fetch_all(&self.pool)
+ .await?)
})
}
- pub async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
+ async fn get_user_by_github_login(&self, github_login: &str) -> Result<Option<User>> {
test_support!(self, {
let query = "SELECT * FROM users WHERE github_login = $1 LIMIT 1";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(github_login)
.fetch_optional(&self.pool)
- .await
+ .await?)
})
}
- pub async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
+ async fn set_user_is_admin(&self, id: UserId, is_admin: bool) -> Result<()> {
test_support!(self, {
let query = "UPDATE users SET admin = $1 WHERE id = $2";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(is_admin)
.bind(id.0)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
- pub async fn destroy_user(&self, id: UserId) -> Result<()> {
+ async fn destroy_user(&self, id: UserId) -> Result<()> {
test_support!(self, {
let query = "DELETE FROM access_tokens WHERE user_id = $1;";
sqlx::query(query)
@@ -173,17 +240,17 @@ impl Db {
.await
.map(drop)?;
let query = "DELETE FROM users WHERE id = $1;";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(id.0)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// access tokens
- pub async fn create_access_token_hash(
+ async fn create_access_token_hash(
&self,
user_id: UserId,
access_token_hash: &str,
@@ -216,11 +283,11 @@ impl Db {
.bind(max_access_token_count as u32)
.execute(&mut tx)
.await?;
- tx.commit().await
+ Ok(tx.commit().await?)
})
}
- pub async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
+ async fn get_access_token_hashes(&self, user_id: UserId) -> Result<Vec<String>> {
test_support!(self, {
let query = "
SELECT hash
@@ -228,10 +295,10 @@ impl Db {
WHERE user_id = $1
ORDER BY id DESC
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(user_id.0)
.fetch_all(&self.pool)
- .await
+ .await?)
})
}
@@ -239,82 +306,77 @@ impl Db {
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
- pub async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
+ async fn find_org_by_slug(&self, slug: &str) -> Result<Option<Org>> {
test_support!(self, {
let query = "
SELECT *
FROM orgs
WHERE slug = $1
";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(slug)
.fetch_optional(&self.pool)
- .await
+ .await?)
})
}
#[cfg(any(test, feature = "seed-support"))]
- pub async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+ async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
test_support!(self, {
let query = "
INSERT INTO orgs (name, slug)
VALUES ($1, $2)
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(name)
.bind(slug)
.fetch_one(&self.pool)
.await
- .map(OrgId)
+ .map(OrgId)?)
})
}
#[cfg(any(test, feature = "seed-support"))]
- pub async fn add_org_member(
- &self,
- org_id: OrgId,
- user_id: UserId,
- is_admin: bool,
- ) -> Result<()> {
+ async fn add_org_member(&self, org_id: OrgId, user_id: UserId, is_admin: bool) -> Result<()> {
test_support!(self, {
let query = "
INSERT INTO org_memberships (org_id, user_id, admin)
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(org_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// channels
#[cfg(any(test, feature = "seed-support"))]
- pub async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+ async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
test_support!(self, {
let query = "
INSERT INTO channels (owner_id, owner_is_user, name)
VALUES ($1, false, $2)
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(org_id.0)
.bind(name)
.fetch_one(&self.pool)
.await
- .map(ChannelId)
+ .map(ChannelId)?)
})
}
#[allow(unused)] // Help rust-analyzer
#[cfg(any(test, feature = "seed-support"))]
- pub async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+ async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT *
@@ -323,32 +385,32 @@ impl Db {
channels.owner_is_user = false AND
channels.owner_id = $1
";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(org_id.0)
.fetch_all(&self.pool)
- .await
+ .await?)
})
}
- pub async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+ async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
test_support!(self, {
let query = "
SELECT
- channels.id, channels.name
+ channels.*
FROM
channel_memberships, channels
WHERE
channel_memberships.user_id = $1 AND
channel_memberships.channel_id = channels.id
";
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(user_id.0)
.fetch_all(&self.pool)
- .await
+ .await?)
})
}
- pub async fn can_user_access_channel(
+ async fn can_user_access_channel(
&self,
user_id: UserId,
channel_id: ChannelId,
@@ -360,17 +422,17 @@ impl Db {
WHERE user_id = $1 AND channel_id = $2
LIMIT 1
";
- sqlx::query_scalar::<_, i32>(query)
+ Ok(sqlx::query_scalar::<_, i32>(query)
.bind(user_id.0)
.bind(channel_id.0)
.fetch_optional(&self.pool)
.await
- .map(|e| e.is_some())
+ .map(|e| e.is_some())?)
})
}
#[cfg(any(test, feature = "seed-support"))]
- pub async fn add_channel_member(
+ async fn add_channel_member(
&self,
channel_id: ChannelId,
user_id: UserId,
@@ -382,19 +444,19 @@ impl Db {
VALUES ($1, $2, $3)
ON CONFLICT DO NOTHING
";
- sqlx::query(query)
+ Ok(sqlx::query(query)
.bind(channel_id.0)
.bind(user_id.0)
.bind(is_admin)
.execute(&self.pool)
.await
- .map(drop)
+ .map(drop)?)
})
}
// messages
- pub async fn create_channel_message(
+ async fn create_channel_message(
&self,
channel_id: ChannelId,
sender_id: UserId,
@@ -409,7 +471,7 @@ impl Db {
ON CONFLICT (nonce) DO UPDATE SET nonce = excluded.nonce
RETURNING id
";
- sqlx::query_scalar(query)
+ Ok(sqlx::query_scalar(query)
.bind(channel_id.0)
.bind(sender_id.0)
.bind(body)
@@ -417,11 +479,11 @@ impl Db {
.bind(Uuid::from_u128(nonce))
.fetch_one(&self.pool)
.await
- .map(MessageId)
+ .map(MessageId)?)
})
}
- pub async fn get_channel_messages(
+ async fn get_channel_messages(
&self,
channel_id: ChannelId,
count: usize,
@@ -431,7 +493,7 @@ impl Db {
let query = r#"
SELECT * FROM (
SELECT
- id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
+ id, channel_id, sender_id, body, sent_at AT TIME ZONE 'UTC' as sent_at, nonce
FROM
channel_messages
WHERE
@@ -442,12 +504,34 @@ impl Db {
) as recent_messages
ORDER BY id ASC
"#;
- sqlx::query_as(query)
+ Ok(sqlx::query_as(query)
.bind(channel_id.0)
.bind(before_id.unwrap_or(MessageId::MAX))
.bind(count as i64)
.fetch_all(&self.pool)
+ .await?)
+ })
+ }
+
+ #[cfg(test)]
+ async fn teardown(&self, name: &str, url: &str) {
+ use util::ResultExt;
+
+ test_support!(self, {
+ let query = "
+ SELECT pg_terminate_backend(pg_stat_activity.pid)
+ FROM pg_stat_activity
+ WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
+ ";
+ sqlx::query(query)
+ .bind(name)
+ .execute(&self.pool)
.await
+ .log_err();
+ self.pool.close().await;
+ <sqlx::Postgres as sqlx::migrate::MigrateDatabase>::drop_database(url)
+ .await
+ .log_err();
})
}
}
@@ -479,7 +563,7 @@ macro_rules! id_type {
}
id_type!(UserId);
-#[derive(Debug, FromRow, Serialize, PartialEq)]
+#[derive(Clone, Debug, FromRow, Serialize, PartialEq)]
pub struct User {
pub id: UserId,
pub github_login: String,
@@ -507,16 +591,19 @@ pub struct Signup {
}
id_type!(ChannelId);
-#[derive(Debug, FromRow, Serialize)]
+#[derive(Clone, Debug, FromRow, Serialize)]
pub struct Channel {
pub id: ChannelId,
pub name: String,
+ pub owner_id: i32,
+ pub owner_is_user: bool,
}
id_type!(MessageId);
-#[derive(Debug, FromRow)]
+#[derive(Clone, Debug, FromRow)]
pub struct ChannelMessage {
pub id: MessageId,
+ pub channel_id: ChannelId,
pub sender_id: UserId,
pub body: String,
pub sent_at: OffsetDateTime,
@@ -526,6 +613,9 @@ pub struct ChannelMessage {
#[cfg(test)]
pub mod tests {
use super::*;
+ use anyhow::anyhow;
+ use collections::BTreeMap;
+ use gpui::{executor::Background, TestAppContext};
use lazy_static::lazy_static;
use parking_lot::Mutex;
use rand::prelude::*;
@@ -533,217 +623,119 @@ pub mod tests {
migrate::{MigrateDatabase, Migrator},
Postgres,
};
- use std::{
- mem,
- path::Path,
- sync::atomic::{AtomicUsize, Ordering::SeqCst},
- };
- use util::ResultExt as _;
-
- pub struct TestDb {
- pub db: Option<Db>,
- pub name: String,
- pub url: String,
- }
+ use std::{path::Path, sync::Arc};
+ use util::post_inc;
- lazy_static! {
- static ref DB_POOL: Mutex<Vec<TestDb>> = Default::default();
- static ref DB_COUNT: AtomicUsize = Default::default();
- }
+ #[gpui::test]
+ async fn test_get_users_by_ids(cx: TestAppContext) {
+ for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+ let db = test_db.db();
- impl TestDb {
- pub fn new() -> Self {
- DB_COUNT.fetch_add(1, SeqCst);
- let mut pool = DB_POOL.lock();
- if let Some(db) = pool.pop() {
- db.truncate();
- db
- } else {
- let mut rng = StdRng::from_entropy();
- let name = format!("zed-test-{}", rng.gen::<u128>());
- let url = format!("postgres://postgres@localhost/{}", name);
- let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
- let db = block_on(async {
- Postgres::create_database(&url)
- .await
- .expect("failed to create test db");
- let mut db = Db::new(&url, 5).await.unwrap();
- db.test_mode = true;
- let migrator = Migrator::new(migrations_path).await.unwrap();
- migrator.run(&db.pool).await.unwrap();
- db
- });
-
- Self {
- db: Some(db),
- name,
- url,
- }
- }
- }
+ let user = db.create_user("user", false).await.unwrap();
+ let friend1 = db.create_user("friend-1", false).await.unwrap();
+ let friend2 = db.create_user("friend-2", false).await.unwrap();
+ let friend3 = db.create_user("friend-3", false).await.unwrap();
- pub fn db(&self) -> &Db {
- self.db.as_ref().unwrap()
- }
-
- fn truncate(&self) {
- block_on(async {
- let query = "
- SELECT tablename FROM pg_tables
- WHERE schemaname = 'public';
- ";
- let table_names = sqlx::query_scalar::<_, String>(query)
- .fetch_all(&self.db().pool)
+ assert_eq!(
+ db.get_users_by_ids(vec![user, friend1, friend2, friend3])
.await
- .unwrap();
- sqlx::query(&format!(
- "TRUNCATE TABLE {} RESTART IDENTITY",
- table_names.join(", ")
- ))
- .execute(&self.db().pool)
- .await
- .unwrap();
- })
- }
-
- async fn teardown(mut self) -> Result<()> {
- let db = self.db.take().unwrap();
- let query = "
- SELECT pg_terminate_backend(pg_stat_activity.pid)
- FROM pg_stat_activity
- WHERE pg_stat_activity.datname = '{}' AND pid <> pg_backend_pid();
- ";
- sqlx::query(query)
- .bind(&self.name)
- .execute(&db.pool)
- .await?;
- db.pool.close().await;
- Postgres::drop_database(&self.url).await?;
- Ok(())
- }
- }
-
- impl Drop for TestDb {
- fn drop(&mut self) {
- if let Some(db) = self.db.take() {
- DB_POOL.lock().push(TestDb {
- db: Some(db),
- name: mem::take(&mut self.name),
- url: mem::take(&mut self.url),
- });
- if DB_COUNT.fetch_sub(1, SeqCst) == 1 {
- block_on(async move {
- let mut pool = DB_POOL.lock();
- for db in pool.drain(..) {
- db.teardown().await.log_err();
- }
- });
- }
- }
+ .unwrap(),
+ vec![
+ User {
+ id: user,
+ github_login: "user".to_string(),
+ admin: false,
+ },
+ User {
+ id: friend1,
+ github_login: "friend-1".to_string(),
+ admin: false,
+ },
+ User {
+ id: friend2,
+ github_login: "friend-2".to_string(),
+ admin: false,
+ },
+ User {
+ id: friend3,
+ github_login: "friend-3".to_string(),
+ admin: false,
+ }
+ ]
+ );
}
}
#[gpui::test]
- async fn test_get_users_by_ids() {
- let test_db = TestDb::new();
- let db = test_db.db();
-
- let user = db.create_user("user", false).await.unwrap();
- let friend1 = db.create_user("friend-1", false).await.unwrap();
- let friend2 = db.create_user("friend-2", false).await.unwrap();
- let friend3 = db.create_user("friend-3", false).await.unwrap();
-
- assert_eq!(
- db.get_users_by_ids([user, friend1, friend2, friend3])
+ async fn test_recent_channel_messages(cx: TestAppContext) {
+ for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+ let db = test_db.db();
+ let user = db.create_user("user", false).await.unwrap();
+ let org = db.create_org("org", "org").await.unwrap();
+ let channel = db.create_org_channel(org, "channel").await.unwrap();
+ for i in 0..10 {
+ db.create_channel_message(
+ channel,
+ user,
+ &i.to_string(),
+ OffsetDateTime::now_utc(),
+ i,
+ )
.await
- .unwrap(),
- vec![
- User {
- id: user,
- github_login: "user".to_string(),
- admin: false,
- },
- User {
- id: friend1,
- github_login: "friend-1".to_string(),
- admin: false,
- },
- User {
- id: friend2,
- github_login: "friend-2".to_string(),
- admin: false,
- },
- User {
- id: friend3,
- github_login: "friend-3".to_string(),
- admin: false,
- }
- ]
- );
- }
+ .unwrap();
+ }
- #[gpui::test]
- async fn test_recent_channel_messages() {
- let test_db = TestDb::new();
- let db = test_db.db();
- let user = db.create_user("user", false).await.unwrap();
- let org = db.create_org("org", "org").await.unwrap();
- let channel = db.create_org_channel(org, "channel").await.unwrap();
- for i in 0..10 {
- db.create_channel_message(channel, user, &i.to_string(), OffsetDateTime::now_utc(), i)
+ let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
+ assert_eq!(
+ messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+ ["5", "6", "7", "8", "9"]
+ );
+
+ let prev_messages = db
+ .get_channel_messages(channel, 4, Some(messages[0].id))
.await
.unwrap();
+ assert_eq!(
+ prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
+ ["1", "2", "3", "4"]
+ );
}
-
- let messages = db.get_channel_messages(channel, 5, None).await.unwrap();
- assert_eq!(
- messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
- ["5", "6", "7", "8", "9"]
- );
-
- let prev_messages = db
- .get_channel_messages(channel, 4, Some(messages[0].id))
- .await
- .unwrap();
- assert_eq!(
- prev_messages.iter().map(|m| &m.body).collect::<Vec<_>>(),
- ["1", "2", "3", "4"]
- );
}
#[gpui::test]
- async fn test_channel_message_nonces() {
- let test_db = TestDb::new();
- let db = test_db.db();
- let user = db.create_user("user", false).await.unwrap();
- let org = db.create_org("org", "org").await.unwrap();
- let channel = db.create_org_channel(org, "channel").await.unwrap();
-
- let msg1_id = db
- .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
- .await
- .unwrap();
- let msg2_id = db
- .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
- .await
- .unwrap();
- let msg3_id = db
- .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
- .await
- .unwrap();
- let msg4_id = db
- .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
- .await
- .unwrap();
+ async fn test_channel_message_nonces(cx: TestAppContext) {
+ for test_db in [TestDb::postgres(), TestDb::fake(cx.background())] {
+ let db = test_db.db();
+ let user = db.create_user("user", false).await.unwrap();
+ let org = db.create_org("org", "org").await.unwrap();
+ let channel = db.create_org_channel(org, "channel").await.unwrap();
+
+ let msg1_id = db
+ .create_channel_message(channel, user, "1", OffsetDateTime::now_utc(), 1)
+ .await
+ .unwrap();
+ let msg2_id = db
+ .create_channel_message(channel, user, "2", OffsetDateTime::now_utc(), 2)
+ .await
+ .unwrap();
+ let msg3_id = db
+ .create_channel_message(channel, user, "3", OffsetDateTime::now_utc(), 1)
+ .await
+ .unwrap();
+ let msg4_id = db
+ .create_channel_message(channel, user, "4", OffsetDateTime::now_utc(), 2)
+ .await
+ .unwrap();
- assert_ne!(msg1_id, msg2_id);
- assert_eq!(msg1_id, msg3_id);
- assert_eq!(msg2_id, msg4_id);
+ assert_ne!(msg1_id, msg2_id);
+ assert_eq!(msg1_id, msg3_id);
+ assert_eq!(msg2_id, msg4_id);
+ }
}
#[gpui::test]
async fn test_create_access_tokens() {
- let test_db = TestDb::new();
+ let test_db = TestDb::postgres();
let db = test_db.db();
let user = db.create_user("the-user", false).await.unwrap();
@@ -772,4 +764,359 @@ pub mod tests {
&["h5".to_string(), "h4".to_string(), "h3".to_string()]
);
}
+
+ pub struct TestDb {
+ pub db: Option<Arc<dyn Db>>,
+ pub name: String,
+ pub url: String,
+ }
+
+ impl TestDb {
+ pub fn postgres() -> Self {
+ lazy_static! {
+ static ref LOCK: Mutex<()> = Mutex::new(());
+ }
+
+ let _guard = LOCK.lock();
+ let mut rng = StdRng::from_entropy();
+ let name = format!("zed-test-{}", rng.gen::<u128>());
+ let url = format!("postgres://postgres@localhost/{}", name);
+ let migrations_path = Path::new(concat!(env!("CARGO_MANIFEST_DIR"), "/migrations"));
+ let db = block_on(async {
+ Postgres::create_database(&url)
+ .await
+ .expect("failed to create test db");
+ let mut db = PostgresDb::new(&url, 5).await.unwrap();
+ db.test_mode = true;
+ let migrator = Migrator::new(migrations_path).await.unwrap();
+ migrator.run(&db.pool).await.unwrap();
+ db
+ });
+ Self {
+ db: Some(Arc::new(db)),
+ name,
+ url,
+ }
+ }
+
+ pub fn fake(background: Arc<Background>) -> Self {
+ Self {
+ db: Some(Arc::new(FakeDb::new(background))),
+ name: "fake".to_string(),
+ url: "fake".to_string(),
+ }
+ }
+
+ pub fn db(&self) -> &Arc<dyn Db> {
+ self.db.as_ref().unwrap()
+ }
+ }
+
+ impl Drop for TestDb {
+ fn drop(&mut self) {
+ if let Some(db) = self.db.take() {
+ block_on(db.teardown(&self.name, &self.url));
+ }
+ }
+ }
+
+ pub struct FakeDb {
+ background: Arc<Background>,
+ users: Mutex<BTreeMap<UserId, User>>,
+ next_user_id: Mutex<i32>,
+ orgs: Mutex<BTreeMap<OrgId, Org>>,
+ next_org_id: Mutex<i32>,
+ org_memberships: Mutex<BTreeMap<(OrgId, UserId), bool>>,
+ channels: Mutex<BTreeMap<ChannelId, Channel>>,
+ next_channel_id: Mutex<i32>,
+ channel_memberships: Mutex<BTreeMap<(ChannelId, UserId), bool>>,
+ channel_messages: Mutex<BTreeMap<MessageId, ChannelMessage>>,
+ next_channel_message_id: Mutex<i32>,
+ }
+
+ impl FakeDb {
+ pub fn new(background: Arc<Background>) -> Self {
+ Self {
+ background,
+ users: Default::default(),
+ next_user_id: Mutex::new(1),
+ orgs: Default::default(),
+ next_org_id: Mutex::new(1),
+ org_memberships: Default::default(),
+ channels: Default::default(),
+ next_channel_id: Mutex::new(1),
+ channel_memberships: Default::default(),
+ channel_messages: Default::default(),
+ next_channel_message_id: Mutex::new(1),
+ }
+ }
+ }
+
+ #[async_trait]
+ impl Db for FakeDb {
+ async fn create_signup(
+ &self,
+ _github_login: &str,
+ _email_address: &str,
+ _about: &str,
+ _wants_releases: bool,
+ _wants_updates: bool,
+ _wants_community: bool,
+ ) -> Result<SignupId> {
+ unimplemented!()
+ }
+
+ async fn get_all_signups(&self) -> Result<Vec<Signup>> {
+ unimplemented!()
+ }
+
+ async fn destroy_signup(&self, _id: SignupId) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn create_user(&self, github_login: &str, admin: bool) -> Result<UserId> {
+ self.background.simulate_random_delay().await;
+
+ let mut users = self.users.lock();
+ if let Some(user) = users
+ .values()
+ .find(|user| user.github_login == github_login)
+ {
+ Ok(user.id)
+ } else {
+ let user_id = UserId(post_inc(&mut *self.next_user_id.lock()));
+ users.insert(
+ user_id,
+ User {
+ id: user_id,
+ github_login: github_login.to_string(),
+ admin,
+ },
+ );
+ Ok(user_id)
+ }
+ }
+
+ async fn get_all_users(&self) -> Result<Vec<User>> {
+ unimplemented!()
+ }
+
+ async fn get_user_by_id(&self, id: UserId) -> Result<Option<User>> {
+ Ok(self.get_users_by_ids(vec![id]).await?.into_iter().next())
+ }
+
+ async fn get_users_by_ids(&self, ids: Vec<UserId>) -> Result<Vec<User>> {
+ self.background.simulate_random_delay().await;
+ let users = self.users.lock();
+ Ok(ids.iter().filter_map(|id| users.get(id).cloned()).collect())
+ }
+
+ async fn get_user_by_github_login(&self, _github_login: &str) -> Result<Option<User>> {
+ unimplemented!()
+ }
+
+ async fn set_user_is_admin(&self, _id: UserId, _is_admin: bool) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn destroy_user(&self, _id: UserId) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn create_access_token_hash(
+ &self,
+ _user_id: UserId,
+ _access_token_hash: &str,
+ _max_access_token_count: usize,
+ ) -> Result<()> {
+ unimplemented!()
+ }
+
+ async fn get_access_token_hashes(&self, _user_id: UserId) -> Result<Vec<String>> {
+ unimplemented!()
+ }
+
+ async fn find_org_by_slug(&self, _slug: &str) -> Result<Option<Org>> {
+ unimplemented!()
+ }
+
+ async fn create_org(&self, name: &str, slug: &str) -> Result<OrgId> {
+ self.background.simulate_random_delay().await;
+ let mut orgs = self.orgs.lock();
+ if orgs.values().any(|org| org.slug == slug) {
+ Err(anyhow!("org already exists"))
+ } else {
+ let org_id = OrgId(post_inc(&mut *self.next_org_id.lock()));
+ orgs.insert(
+ org_id,
+ Org {
+ id: org_id,
+ name: name.to_string(),
+ slug: slug.to_string(),
+ },
+ );
+ Ok(org_id)
+ }
+ }
+
+ async fn add_org_member(
+ &self,
+ org_id: OrgId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
+ self.background.simulate_random_delay().await;
+ if !self.orgs.lock().contains_key(&org_id) {
+ return Err(anyhow!("org does not exist"));
+ }
+ if !self.users.lock().contains_key(&user_id) {
+ return Err(anyhow!("user does not exist"));
+ }
+
+ self.org_memberships
+ .lock()
+ .entry((org_id, user_id))
+ .or_insert(is_admin);
+ Ok(())
+ }
+
+ async fn create_org_channel(&self, org_id: OrgId, name: &str) -> Result<ChannelId> {
+ self.background.simulate_random_delay().await;
+ if !self.orgs.lock().contains_key(&org_id) {
+ return Err(anyhow!("org does not exist"));
+ }
+
+ let mut channels = self.channels.lock();
+ let channel_id = ChannelId(post_inc(&mut *self.next_channel_id.lock()));
+ channels.insert(
+ channel_id,
+ Channel {
+ id: channel_id,
+ name: name.to_string(),
+ owner_id: org_id.0,
+ owner_is_user: false,
+ },
+ );
+ Ok(channel_id)
+ }
+
+ async fn get_org_channels(&self, org_id: OrgId) -> Result<Vec<Channel>> {
+ self.background.simulate_random_delay().await;
+ Ok(self
+ .channels
+ .lock()
+ .values()
+ .filter(|channel| !channel.owner_is_user && channel.owner_id == org_id.0)
+ .cloned()
+ .collect())
+ }
+
+ async fn get_accessible_channels(&self, user_id: UserId) -> Result<Vec<Channel>> {
+ self.background.simulate_random_delay().await;
+ let channels = self.channels.lock();
+ let memberships = self.channel_memberships.lock();
+ Ok(channels
+ .values()
+ .filter(|channel| memberships.contains_key(&(channel.id, user_id)))
+ .cloned()
+ .collect())
+ }
+
+ async fn can_user_access_channel(
+ &self,
+ user_id: UserId,
+ channel_id: ChannelId,
+ ) -> Result<bool> {
+ self.background.simulate_random_delay().await;
+ Ok(self
+ .channel_memberships
+ .lock()
+ .contains_key(&(channel_id, user_id)))
+ }
+
+ async fn add_channel_member(
+ &self,
+ channel_id: ChannelId,
+ user_id: UserId,
+ is_admin: bool,
+ ) -> Result<()> {
+ self.background.simulate_random_delay().await;
+ if !self.channels.lock().contains_key(&channel_id) {
+ return Err(anyhow!("channel does not exist"));
+ }
+ if !self.users.lock().contains_key(&user_id) {
+ return Err(anyhow!("user does not exist"));
+ }
+
+ self.channel_memberships
+ .lock()
+ .entry((channel_id, user_id))
+ .or_insert(is_admin);
+ Ok(())
+ }
+
+ async fn create_channel_message(
+ &self,
+ channel_id: ChannelId,
+ sender_id: UserId,
+ body: &str,
+ timestamp: OffsetDateTime,
+ nonce: u128,
+ ) -> Result<MessageId> {
+ self.background.simulate_random_delay().await;
+ if !self.channels.lock().contains_key(&channel_id) {
+ return Err(anyhow!("channel does not exist"));
+ }
+ if !self.users.lock().contains_key(&sender_id) {
+ return Err(anyhow!("user does not exist"));
+ }
+
+ let mut messages = self.channel_messages.lock();
+ if let Some(message) = messages
+ .values()
+ .find(|message| message.nonce.as_u128() == nonce)
+ {
+ Ok(message.id)
+ } else {
+ let message_id = MessageId(post_inc(&mut *self.next_channel_message_id.lock()));
+ messages.insert(
+ message_id,
+ ChannelMessage {
+ id: message_id,
+ channel_id,
+ sender_id,
+ body: body.to_string(),
+ sent_at: timestamp,
+ nonce: Uuid::from_u128(nonce),
+ },
+ );
+ Ok(message_id)
+ }
+ }
+
+ async fn get_channel_messages(
+ &self,
+ channel_id: ChannelId,
+ count: usize,
+ before_id: Option<MessageId>,
+ ) -> Result<Vec<ChannelMessage>> {
+ let mut messages = self
+ .channel_messages
+ .lock()
+ .values()
+ .rev()
+ .filter(|message| {
+ message.channel_id == channel_id
+ && message.id < before_id.unwrap_or(MessageId::MAX)
+ })
+ .take(count)
+ .cloned()
+ .collect::<Vec<_>>();
+ dbg!(count, before_id, &messages);
+ messages.sort_unstable_by_key(|message| message.id);
+ Ok(messages)
+ }
+
+ async fn teardown(&self, _name: &str, _url: &str) {}
+ }
}
@@ -20,7 +20,7 @@ use anyhow::Result;
use async_std::net::TcpListener;
use async_trait::async_trait;
use auth::RequestExt as _;
-use db::Db;
+use db::{Db, PostgresDb};
use handlebars::{Handlebars, TemplateRenderError};
use parking_lot::RwLock;
use rust_embed::RustEmbed;
@@ -49,7 +49,7 @@ pub struct Config {
}
pub struct AppState {
- db: Db,
+ db: Arc<dyn Db>,
handlebars: RwLock<Handlebars<'static>>,
auth_client: auth::Client,
github_client: Arc<github::AppClient>,
@@ -59,7 +59,7 @@ pub struct AppState {
impl AppState {
async fn new(config: Config) -> tide::Result<Arc<Self>> {
- let db = Db::new(&config.database_url, 5).await?;
+ let db = PostgresDb::new(&config.database_url, 5).await?;
let github_client =
github::AppClient::new(config.github_app_id, config.github_private_key.clone());
let repo_client = github_client
@@ -68,7 +68,7 @@ impl AppState {
.context("failed to initialize github client")?;
let this = Self {
- db,
+ db: Arc::new(db),
handlebars: Default::default(),
auth_client: auth::build_client(&config.github_client_id, &config.github_client_secret),
github_client,
@@ -112,7 +112,7 @@ impl AppState {
#[async_trait]
trait RequestExt {
async fn layout_data(&mut self) -> tide::Result<Arc<LayoutData>>;
- fn db(&self) -> &Db;
+ fn db(&self) -> &Arc<dyn Db>;
}
#[async_trait]
@@ -126,7 +126,7 @@ impl RequestExt for Request {
Ok(self.ext::<Arc<LayoutData>>().unwrap().clone())
}
- fn db(&self) -> &Db {
+ fn db(&self) -> &Arc<dyn Db> {
&self.state().db
}
}
@@ -9,9 +9,8 @@ use anyhow::anyhow;
use async_std::task;
use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use collections::{HashMap, HashSet};
-use futures::{future::BoxFuture, FutureExt, StreamExt};
+use futures::{channel::mpsc, future::BoxFuture, FutureExt, SinkExt, StreamExt};
use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard};
-use postage::{mpsc, prelude::Sink as _};
use rpc::{
proto::{self, AnyTypedEnvelope, EnvelopedMessage, RequestMessage},
Connection, ConnectionId, Peer, TypedEnvelope,
@@ -38,9 +37,15 @@ pub struct Server {
store: RwLock<Store>,
app_state: Arc<AppState>,
handlers: HashMap<TypeId, MessageHandler>,
- notifications: Option<mpsc::Sender<()>>,
+ notifications: Option<mpsc::UnboundedSender<()>>,
}
+pub trait Executor {
+ fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F);
+}
+
+pub struct RealExecutor;
+
const MESSAGE_COUNT_PER_PAGE: usize = 100;
const MAX_MESSAGE_LEN: usize = 1024;
@@ -48,7 +53,7 @@ impl Server {
pub fn new(
app_state: Arc<AppState>,
peer: Arc<Peer>,
- notifications: Option<mpsc::Sender<()>>,
+ notifications: Option<mpsc::UnboundedSender<()>>,
) -> Arc<Self> {
let mut server = Self {
peer,
@@ -69,7 +74,7 @@ impl Server {
.add_request_handler(Server::register_worktree)
.add_message_handler(Server::unregister_worktree)
.add_request_handler(Server::share_worktree)
- .add_message_handler(Server::update_worktree)
+ .add_request_handler(Server::update_worktree)
.add_message_handler(Server::update_diagnostic_summary)
.add_message_handler(Server::disk_based_diagnostics_updating)
.add_message_handler(Server::disk_based_diagnostics_updated)
@@ -144,12 +149,13 @@ impl Server {
})
}
- pub fn handle_connection(
+ pub fn handle_connection<E: Executor>(
self: &Arc<Self>,
connection: Connection,
addr: String,
user_id: UserId,
- mut send_connection_id: Option<postage::mpsc::Sender<ConnectionId>>,
+ mut send_connection_id: Option<mpsc::Sender<ConnectionId>>,
+ executor: E,
) -> impl Future<Output = ()> {
let mut this = self.clone();
async move {
@@ -183,14 +189,23 @@ impl Server {
let type_name = message.payload_type_name();
log::info!("rpc message received. connection:{}, type:{}", connection_id, type_name);
if let Some(handler) = this.handlers.get(&message.payload_type_id()) {
- if let Err(err) = (handler)(this.clone(), message).await {
- log::error!("rpc message error. connection:{}, type:{}, error:{:?}", connection_id, type_name, err);
+ let notifications = this.notifications.clone();
+ let is_background = message.is_background();
+ let handle_message = (handler)(this.clone(), message);
+ let handle_message = async move {
+ if let Err(err) = handle_message.await {
+ log::error!("rpc message error. connection:{}, type:{}, error:{:?}", connection_id, type_name, err);
+ } else {
+ log::info!("rpc message handled. connection:{}, type:{}, duration:{:?}", connection_id, type_name, start_time.elapsed());
+ }
+ if let Some(mut notifications) = notifications {
+ let _ = notifications.send(()).await;
+ }
+ };
+ if is_background {
+ executor.spawn_detached(handle_message);
} else {
- log::info!("rpc message handled. connection:{}, type:{}, duration:{:?}", connection_id, type_name, start_time.elapsed());
- }
-
- if let Some(mut notifications) = this.notifications.clone() {
- let _ = notifications.send(()).await;
+ handle_message.await;
}
} else {
log::warn!("unhandled message: {}", type_name);
@@ -329,6 +344,7 @@ impl Server {
.cloned()
.collect(),
weak: worktree.weak,
+ next_update_id: share.next_update_id as u64,
})
})
.collect();
@@ -467,6 +483,7 @@ impl Server {
request.sender_id,
entries,
diagnostic_summaries,
+ worktree.next_update_id,
)?;
broadcast(
@@ -485,11 +502,12 @@ impl Server {
async fn update_worktree(
mut self: Arc<Server>,
request: TypedEnvelope<proto::UpdateWorktree>,
- ) -> tide::Result<()> {
+ ) -> tide::Result<proto::Ack> {
let connection_ids = self.state_mut().update_worktree(
request.sender_id,
request.payload.project_id,
request.payload.worktree_id,
+ request.payload.id,
&request.payload.removed_entries,
&request.payload.updated_entries,
)?;
@@ -499,7 +517,7 @@ impl Server {
.forward_send(request.sender_id, connection_id, request.payload.clone())
})?;
- Ok(())
+ Ok(proto::Ack {})
}
async fn update_diagnostic_summary(
@@ -767,7 +785,12 @@ impl Server {
self: Arc<Server>,
request: TypedEnvelope<proto::GetUsers>,
) -> tide::Result<proto::GetUsersResponse> {
- let user_ids = request.payload.user_ids.into_iter().map(UserId::from_proto);
+ let user_ids = request
+ .payload
+ .user_ids
+ .into_iter()
+ .map(UserId::from_proto)
+ .collect();
let users = self
.app_state
.db
@@ -966,6 +989,12 @@ impl Server {
}
}
+impl Executor for RealExecutor {
+ fn spawn_detached<F: 'static + Send + Future<Output = ()>>(&self, future: F) {
+ task::spawn(future);
+ }
+}
+
fn broadcast<F>(
sender_id: ConnectionId,
receiver_ids: Vec<ConnectionId>,
@@ -1032,6 +1061,7 @@ pub fn add_routes(app: &mut tide::Server<Arc<AppState>>, rpc: &Arc<Peer>) {
addr,
user_id,
None,
+ RealExecutor,
)
.await;
}
@@ -1066,15 +1096,17 @@ mod tests {
github, AppState, Config,
};
use ::rpc::Peer;
- use async_std::task;
+ use collections::BTreeMap;
use gpui::{executor, ModelHandle, TestAppContext};
use parking_lot::Mutex;
- use postage::{mpsc, watch};
+ use postage::{sink::Sink, watch};
use rand::prelude::*;
use rpc::PeerId;
use serde_json::json;
use sqlx::types::time::OffsetDateTime;
use std::{
+ cell::{Cell, RefCell},
+ env,
ops::Deref,
path::Path,
rc::Rc,
@@ -1099,7 +1131,7 @@ mod tests {
LanguageConfig, LanguageRegistry, LanguageServerConfig, Point,
},
lsp,
- project::{worktree::WorktreeHandle, DiagnosticSummary, Project, ProjectPath},
+ project::{DiagnosticSummary, Project, ProjectPath},
workspace::{Workspace, WorkspaceParams},
};
@@ -1119,7 +1151,7 @@ mod tests {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1257,7 +1289,7 @@ mod tests {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1358,7 +1390,7 @@ mod tests {
cx_a.foreground().forbid_parking();
// Connect to a server as 3 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
let client_c = server.create_client(&mut cx_c, "user_c").await;
@@ -1470,11 +1502,6 @@ mod tests {
buffer_b.read_with(&cx_b, |buf, _| assert!(!buf.is_dirty()));
buffer_c.condition(&cx_c, |buf, _| !buf.is_dirty()).await;
- // Ensure worktree observes a/file1's change event *before* the rename occurs, otherwise
- // when interpreting the change event it will mistakenly think that the file has been
- // deleted (because its path has changed) and will subsequently fail to detect the rename.
- worktree_a.flush_fs_events(&cx_a).await;
-
// Make changes on host's file system, see those changes on guest worktrees.
fs.rename(
"/a/file1".as_ref(),
@@ -1483,6 +1510,7 @@ mod tests {
)
.await
.unwrap();
+
fs.rename("/a/file2".as_ref(), "/a/file3".as_ref(), Default::default())
.await
.unwrap();
@@ -1540,7 +1568,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1590,14 +1618,12 @@ mod tests {
)
.await
.unwrap();
- let worktree_b = project_b.update(&mut cx_b, |p, cx| p.worktrees(cx).next().unwrap());
// Open a buffer as client B
let buffer_b = project_b
.update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx))
.await
.unwrap();
- let mtime = buffer_b.read_with(&cx_b, |buf, _| buf.file().unwrap().mtime());
buffer_b.update(&mut cx_b, |buf, cx| buf.edit([0..0], "world ", cx));
buffer_b.read_with(&cx_b, |buf, _| {
@@ -1609,13 +1635,10 @@ mod tests {
.update(&mut cx_b, |buf, cx| buf.save(cx))
.await
.unwrap();
- worktree_b
- .condition(&cx_b, |_, cx| {
- buffer_b.read(cx).file().unwrap().mtime() != mtime
- })
+ buffer_b
+ .condition(&cx_b, |buffer_b, _| !buffer_b.is_dirty())
.await;
buffer_b.read_with(&cx_b, |buf, _| {
- assert!(!buf.is_dirty());
assert!(!buf.has_conflict());
});
@@ -1633,7 +1656,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1718,7 +1741,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1778,10 +1801,12 @@ mod tests {
let buffer_b = cx_b
.background()
.spawn(project_b.update(&mut cx_b, |p, cx| p.open_buffer((worktree_id, "a.txt"), cx)));
- task::yield_now().await;
// Edit the buffer as client A while client B is still opening it.
- buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "z", cx));
+ cx_b.background().simulate_random_delay().await;
+ buffer_a.update(&mut cx_a, |buf, cx| buf.edit([0..0], "X", cx));
+ cx_b.background().simulate_random_delay().await;
+ buffer_a.update(&mut cx_a, |buf, cx| buf.edit([1..1], "Y", cx));
let text = buffer_a.read_with(&cx_a, |buf, _| buf.text());
let buffer_b = buffer_b.await.unwrap();
@@ -1798,7 +1823,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1873,7 +1898,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -1947,8 +1972,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Set up a fake language server.
- let (language_server_config, mut fake_language_server) =
- LanguageServerConfig::fake(&cx_a).await;
+ let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake();
Arc::get_mut(&mut lang_registry)
.unwrap()
.add(Arc::new(Language::new(
@@ -1962,7 +1986,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2017,6 +2041,7 @@ mod tests {
.unwrap();
// Simulate a language server reporting errors for a file.
+ let mut fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server
.notify::<lsp::notification::PublishDiagnostics>(lsp::PublishDiagnosticsParams {
uri: lsp::Url::from_file_path("/a/a.rs").unwrap(),
@@ -2171,18 +2196,14 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Set up a fake language server.
- let (language_server_config, mut fake_language_server) =
- LanguageServerConfig::fake_with_capabilities(
- lsp::ServerCapabilities {
- completion_provider: Some(lsp::CompletionOptions {
- trigger_characters: Some(vec![".".to_string()]),
- ..Default::default()
- }),
- ..Default::default()
- },
- &cx_a,
- )
- .await;
+ let (mut language_server_config, mut fake_language_servers) = LanguageServerConfig::fake();
+ language_server_config.set_fake_capabilities(lsp::ServerCapabilities {
+ completion_provider: Some(lsp::CompletionOptions {
+ trigger_characters: Some(vec![".".to_string()]),
+ ..Default::default()
+ }),
+ ..Default::default()
+ });
Arc::get_mut(&mut lang_registry)
.unwrap()
.add(Arc::new(Language::new(
@@ -2196,7 +2217,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2264,6 +2285,11 @@ mod tests {
)
});
+ let mut fake_language_server = fake_language_servers.next().await.unwrap();
+ buffer_b
+ .condition(&cx_b, |buffer, _| !buffer.completion_triggers().is_empty())
+ .await;
+
// Type a completion trigger character as the guest.
editor_b.update(&mut cx_b, |editor, cx| {
editor.select_ranges([13..13], None, cx);
@@ -2273,45 +2299,51 @@ mod tests {
// Receive a completion request as the host's language server.
// Return some completions from the host's language server.
- fake_language_server.handle_request::<lsp::request::Completion, _>(|params| {
- assert_eq!(
- params.text_document_position.text_document.uri,
- lsp::Url::from_file_path("/a/main.rs").unwrap(),
- );
- assert_eq!(
- params.text_document_position.position,
- lsp::Position::new(0, 14),
- );
+ cx_a.foreground().start_waiting();
+ fake_language_server
+ .handle_request::<lsp::request::Completion, _>(|params| {
+ assert_eq!(
+ params.text_document_position.text_document.uri,
+ lsp::Url::from_file_path("/a/main.rs").unwrap(),
+ );
+ assert_eq!(
+ params.text_document_position.position,
+ lsp::Position::new(0, 14),
+ );
- Some(lsp::CompletionResponse::Array(vec![
- lsp::CompletionItem {
- label: "first_method(…)".into(),
- detail: Some("fn(&mut self, B) -> C".into()),
- text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
- new_text: "first_method($1)".to_string(),
- range: lsp::Range::new(
- lsp::Position::new(0, 14),
- lsp::Position::new(0, 14),
- ),
- })),
- insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
- ..Default::default()
- },
- lsp::CompletionItem {
- label: "second_method(…)".into(),
- detail: Some("fn(&mut self, C) -> D<E>".into()),
- text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
- new_text: "second_method()".to_string(),
- range: lsp::Range::new(
- lsp::Position::new(0, 14),
- lsp::Position::new(0, 14),
- ),
- })),
- insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
- ..Default::default()
- },
- ]))
- });
+ Some(lsp::CompletionResponse::Array(vec![
+ lsp::CompletionItem {
+ label: "first_method(…)".into(),
+ detail: Some("fn(&mut self, B) -> C".into()),
+ text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
+ new_text: "first_method($1)".to_string(),
+ range: lsp::Range::new(
+ lsp::Position::new(0, 14),
+ lsp::Position::new(0, 14),
+ ),
+ })),
+ insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
+ ..Default::default()
+ },
+ lsp::CompletionItem {
+ label: "second_method(…)".into(),
+ detail: Some("fn(&mut self, C) -> D<E>".into()),
+ text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
+ new_text: "second_method()".to_string(),
+ range: lsp::Range::new(
+ lsp::Position::new(0, 14),
+ lsp::Position::new(0, 14),
+ ),
+ })),
+ insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
+ ..Default::default()
+ },
+ ]))
+ })
+ .next()
+ .await
+ .unwrap();
+ cx_a.foreground().finish_waiting();
// Open the buffer on the host.
let buffer_a = project_a
@@ -2325,9 +2357,10 @@ mod tests {
.await;
// Confirm a completion on the guest.
- editor_b.next_notification(&cx_b).await;
+ editor_b
+ .condition(&cx_b, |editor, _| editor.context_menu_visible())
+ .await;
editor_b.update(&mut cx_b, |editor, cx| {
- assert!(editor.context_menu_visible());
editor.confirm_completion(&ConfirmCompletion(Some(0)), cx);
assert_eq!(editor.text(cx), "fn main() { a.first_method() }");
});
@@ -2352,22 +2385,17 @@ mod tests {
}
});
+ // The additional edit is applied.
buffer_a
.condition(&cx_a, |buffer, _| {
- buffer.text() == "fn main() { a.first_method() }"
+ buffer.text() == "use d::SomeTrait;\nfn main() { a.first_method() }"
})
.await;
-
- // The additional edit is applied.
buffer_b
.condition(&cx_b, |buffer, _| {
buffer.text() == "use d::SomeTrait;\nfn main() { a.first_method() }"
})
.await;
- assert_eq!(
- buffer_a.read_with(&cx_a, |buffer, _| buffer.text()),
- buffer_b.read_with(&cx_b, |buffer, _| buffer.text()),
- );
}
#[gpui::test(iterations = 10)]
@@ -2377,8 +2405,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Set up a fake language server.
- let (language_server_config, mut fake_language_server) =
- LanguageServerConfig::fake(&cx_a).await;
+ let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake();
Arc::get_mut(&mut lang_registry)
.unwrap()
.add(Arc::new(Language::new(
@@ -2392,7 +2419,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2452,6 +2479,7 @@ mod tests {
project.format(HashSet::from_iter([buffer_b.clone()]), true, cx)
});
+ let mut fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.handle_request::<lsp::request::Formatting, _>(|_| {
Some(vec![
lsp::TextEdit {
@@ -2494,8 +2522,7 @@ mod tests {
.await;
// Set up a fake language server.
- let (language_server_config, mut fake_language_server) =
- LanguageServerConfig::fake(&cx_a).await;
+ let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake();
Arc::get_mut(&mut lang_registry)
.unwrap()
.add(Arc::new(Language::new(
@@ -2509,7 +2536,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2560,6 +2587,8 @@ mod tests {
// Request the definition of a symbol as the guest.
let definitions_1 = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b, 23, cx));
+
+ let mut fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.handle_request::<lsp::request::GotoDefinition, _>(|_| {
Some(lsp::GotoDefinitionResponse::Scalar(lsp::Location::new(
lsp::Url::from_file_path("/root-2/b.rs").unwrap(),
@@ -2640,8 +2669,8 @@ mod tests {
.await;
// Set up a fake language server.
- let (language_server_config, mut fake_language_server) =
- LanguageServerConfig::fake(&cx_a).await;
+ let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake();
+
Arc::get_mut(&mut lang_registry)
.unwrap()
.add(Arc::new(Language::new(
@@ -2655,7 +2684,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2669,6 +2698,7 @@ mod tests {
cx,
)
});
+
let (worktree_a, _) = project_a
.update(&mut cx_a, |p, cx| {
p.find_or_create_local_worktree("/root", false, cx)
@@ -2715,6 +2745,7 @@ mod tests {
definitions = project_b.update(&mut cx_b, |p, cx| p.definition(&buffer_b1, 23, cx));
}
+ let mut fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server.handle_request::<lsp::request::GotoDefinition, _>(|_| {
Some(lsp::GotoDefinitionResponse::Scalar(lsp::Location::new(
lsp::Url::from_file_path("/root/b.rs").unwrap(),
@@ -2740,14 +2771,7 @@ mod tests {
cx_b.update(|cx| editor::init(cx, &mut path_openers_b));
// Set up a fake language server.
- let (language_server_config, mut fake_language_server) =
- LanguageServerConfig::fake_with_capabilities(
- lsp::ServerCapabilities {
- ..Default::default()
- },
- &cx_a,
- )
- .await;
+ let (language_server_config, mut fake_language_servers) = LanguageServerConfig::fake();
Arc::get_mut(&mut lang_registry)
.unwrap()
.add(Arc::new(Language::new(
@@ -2761,7 +2785,7 @@ mod tests {
)));
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -2827,6 +2851,8 @@ mod tests {
.unwrap()
.downcast::<Editor>()
.unwrap();
+
+ let mut fake_language_server = fake_language_servers.next().await.unwrap();
fake_language_server
.handle_request::<lsp::request::CodeActionRequest, _>(|params| {
assert_eq!(
@@ -2845,58 +2871,62 @@ mod tests {
editor.select_ranges([Point::new(1, 31)..Point::new(1, 31)], None, cx);
cx.focus(&editor_b);
});
- fake_language_server.handle_request::<lsp::request::CodeActionRequest, _>(|params| {
- assert_eq!(
- params.text_document.uri,
- lsp::Url::from_file_path("/a/main.rs").unwrap(),
- );
- assert_eq!(params.range.start, lsp::Position::new(1, 31));
- assert_eq!(params.range.end, lsp::Position::new(1, 31));
-
- Some(vec![lsp::CodeActionOrCommand::CodeAction(
- lsp::CodeAction {
- title: "Inline into all callers".to_string(),
- edit: Some(lsp::WorkspaceEdit {
- changes: Some(
- [
- (
- lsp::Url::from_file_path("/a/main.rs").unwrap(),
- vec![lsp::TextEdit::new(
- lsp::Range::new(
- lsp::Position::new(1, 22),
- lsp::Position::new(1, 34),
- ),
- "4".to_string(),
- )],
- ),
- (
- lsp::Url::from_file_path("/a/other.rs").unwrap(),
- vec![lsp::TextEdit::new(
- lsp::Range::new(
- lsp::Position::new(0, 0),
- lsp::Position::new(0, 27),
- ),
- "".to_string(),
- )],
- ),
- ]
- .into_iter()
- .collect(),
- ),
- ..Default::default()
- }),
- data: Some(json!({
- "codeActionParams": {
- "range": {
- "start": {"line": 1, "column": 31},
- "end": {"line": 1, "column": 31},
+
+ fake_language_server
+ .handle_request::<lsp::request::CodeActionRequest, _>(|params| {
+ assert_eq!(
+ params.text_document.uri,
+ lsp::Url::from_file_path("/a/main.rs").unwrap(),
+ );
+ assert_eq!(params.range.start, lsp::Position::new(1, 31));
+ assert_eq!(params.range.end, lsp::Position::new(1, 31));
+
+ Some(vec![lsp::CodeActionOrCommand::CodeAction(
+ lsp::CodeAction {
+ title: "Inline into all callers".to_string(),
+ edit: Some(lsp::WorkspaceEdit {
+ changes: Some(
+ [
+ (
+ lsp::Url::from_file_path("/a/main.rs").unwrap(),
+ vec![lsp::TextEdit::new(
+ lsp::Range::new(
+ lsp::Position::new(1, 22),
+ lsp::Position::new(1, 34),
+ ),
+ "4".to_string(),
+ )],
+ ),
+ (
+ lsp::Url::from_file_path("/a/other.rs").unwrap(),
+ vec![lsp::TextEdit::new(
+ lsp::Range::new(
+ lsp::Position::new(0, 0),
+ lsp::Position::new(0, 27),
+ ),
+ "".to_string(),
+ )],
+ ),
+ ]
+ .into_iter()
+ .collect(),
+ ),
+ ..Default::default()
+ }),
+ data: Some(json!({
+ "codeActionParams": {
+ "range": {
+ "start": {"line": 1, "column": 31},
+ "end": {"line": 1, "column": 31},
+ }
}
- }
- })),
- ..Default::default()
- },
- )])
- });
+ })),
+ ..Default::default()
+ },
+ )])
+ })
+ .next()
+ .await;
// Toggle code actions and wait for them to display.
editor_b.update(&mut cx_b, |editor, cx| {
@@ -2906,6 +2936,8 @@ mod tests {
.condition(&cx_b, |editor, _| editor.context_menu_visible())
.await;
+ fake_language_server.remove_request_handler::<lsp::request::CodeActionRequest>();
+
// Confirming the code action will trigger a resolve request.
let confirm_action = workspace_b
.update(&mut cx_b, |workspace, cx| {
@@ -2974,7 +3006,7 @@ mod tests {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
@@ -3113,7 +3145,7 @@ mod tests {
async fn test_chat_message_validation(mut cx_a: TestAppContext) {
cx_a.foreground().forbid_parking();
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let db = &server.app_state.db;
@@ -3174,7 +3206,7 @@ mod tests {
cx_a.foreground().forbid_parking();
// Connect to a server as 2 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
let mut status_b = client_b.status();
@@ -3392,7 +3424,7 @@ mod tests {
let fs = Arc::new(FakeFs::new(cx_a.background()));
// Connect to a server as 3 clients.
- let mut server = TestServer::start(cx_a.foreground()).await;
+ let mut server = TestServer::start(cx_a.foreground(), cx_a.background()).await;
let client_a = server.create_client(&mut cx_a, "user_a").await;
let client_b = server.create_client(&mut cx_b, "user_b").await;
let client_c = server.create_client(&mut cx_c, "user_c").await;
@@ -3523,23 +3555,269 @@ mod tests {
}
}
+ #[gpui::test(iterations = 100)]
+ async fn test_random_collaboration(cx: TestAppContext, rng: StdRng) {
+ cx.foreground().forbid_parking();
+ let max_peers = env::var("MAX_PEERS")
+ .map(|i| i.parse().expect("invalid `MAX_PEERS` variable"))
+ .unwrap_or(5);
+ let max_operations = env::var("OPERATIONS")
+ .map(|i| i.parse().expect("invalid `OPERATIONS` variable"))
+ .unwrap_or(10);
+
+ let rng = Rc::new(RefCell::new(rng));
+
+ let mut host_lang_registry = Arc::new(LanguageRegistry::new());
+ let guest_lang_registry = Arc::new(LanguageRegistry::new());
+
+ // Set up a fake language server.
+ let (mut language_server_config, _fake_language_servers) = LanguageServerConfig::fake();
+ language_server_config.set_fake_initializer(|fake_server| {
+ fake_server.handle_request::<lsp::request::Completion, _>(|_| {
+ Some(lsp::CompletionResponse::Array(vec![lsp::CompletionItem {
+ text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
+ range: lsp::Range::new(lsp::Position::new(0, 0), lsp::Position::new(0, 0)),
+ new_text: "the-new-text".to_string(),
+ })),
+ ..Default::default()
+ }]))
+ });
+
+ fake_server.handle_request::<lsp::request::CodeActionRequest, _>(|_| {
+ Some(vec![lsp::CodeActionOrCommand::CodeAction(
+ lsp::CodeAction {
+ title: "the-code-action".to_string(),
+ ..Default::default()
+ },
+ )])
+ });
+ });
+
+ Arc::get_mut(&mut host_lang_registry)
+ .unwrap()
+ .add(Arc::new(Language::new(
+ LanguageConfig {
+ name: "Rust".to_string(),
+ path_suffixes: vec!["rs".to_string()],
+ language_server: Some(language_server_config),
+ ..Default::default()
+ },
+ None,
+ )));
+
+ let fs = Arc::new(FakeFs::new(cx.background()));
+ fs.insert_tree(
+ "/_collab",
+ json!({
+ ".zed.toml": r#"collaborators = ["guest-1", "guest-2", "guest-3", "guest-4", "guest-5"]"#
+ }),
+ )
+ .await;
+
+ let operations = Rc::new(Cell::new(0));
+ let mut server = TestServer::start(cx.foreground(), cx.background()).await;
+ let mut clients = Vec::new();
+
+ let mut next_entity_id = 100000;
+ let mut host_cx = TestAppContext::new(
+ cx.foreground_platform(),
+ cx.platform(),
+ cx.foreground(),
+ cx.background(),
+ cx.font_cache(),
+ next_entity_id,
+ );
+ let host = server.create_client(&mut host_cx, "host").await;
+ let host_project = host_cx.update(|cx| {
+ Project::local(
+ host.client.clone(),
+ host.user_store.clone(),
+ host_lang_registry.clone(),
+ fs.clone(),
+ cx,
+ )
+ });
+ let host_project_id = host_project
+ .update(&mut host_cx, |p, _| p.next_remote_id())
+ .await;
+
+ let (collab_worktree, _) = host_project
+ .update(&mut host_cx, |project, cx| {
+ project.find_or_create_local_worktree("/_collab", false, cx)
+ })
+ .await
+ .unwrap();
+ collab_worktree
+ .read_with(&host_cx, |tree, _| tree.as_local().unwrap().scan_complete())
+ .await;
+ host_project
+ .update(&mut host_cx, |project, cx| project.share(cx))
+ .await
+ .unwrap();
+
+ clients.push(cx.foreground().spawn(host.simulate_host(
+ host_project.clone(),
+ operations.clone(),
+ max_operations,
+ rng.clone(),
+ host_cx.clone(),
+ )));
+
+ while operations.get() < max_operations {
+ cx.background().simulate_random_delay().await;
+ if clients.len() < max_peers && rng.borrow_mut().gen_bool(0.05) {
+ operations.set(operations.get() + 1);
+
+ let guest_id = clients.len();
+ log::info!("Adding guest {}", guest_id);
+ next_entity_id += 100000;
+ let mut guest_cx = TestAppContext::new(
+ cx.foreground_platform(),
+ cx.platform(),
+ cx.foreground(),
+ cx.background(),
+ cx.font_cache(),
+ next_entity_id,
+ );
+ let guest = server
+ .create_client(&mut guest_cx, &format!("guest-{}", guest_id))
+ .await;
+ let guest_project = Project::remote(
+ host_project_id,
+ guest.client.clone(),
+ guest.user_store.clone(),
+ guest_lang_registry.clone(),
+ fs.clone(),
+ &mut guest_cx.to_async(),
+ )
+ .await
+ .unwrap();
+ clients.push(cx.foreground().spawn(guest.simulate_guest(
+ guest_id,
+ guest_project,
+ operations.clone(),
+ max_operations,
+ rng.clone(),
+ guest_cx,
+ )));
+
+ log::info!("Guest {} added", guest_id);
+ }
+ }
+
+ let clients = futures::future::join_all(clients).await;
+ cx.foreground().run_until_parked();
+
+ let host_worktree_snapshots = host_project.read_with(&host_cx, |project, cx| {
+ project
+ .worktrees(cx)
+ .map(|worktree| {
+ let snapshot = worktree.read(cx).snapshot();
+ (snapshot.id(), snapshot)
+ })
+ .collect::<BTreeMap<_, _>>()
+ });
+
+ for (guest_client, guest_cx) in clients.iter().skip(1) {
+ let guest_id = guest_client.client.id();
+ let worktree_snapshots =
+ guest_client
+ .project
+ .as_ref()
+ .unwrap()
+ .read_with(guest_cx, |project, cx| {
+ project
+ .worktrees(cx)
+ .map(|worktree| {
+ let worktree = worktree.read(cx);
+ assert!(
+ !worktree.as_remote().unwrap().has_pending_updates(),
+ "Guest {} worktree {:?} contains deferred updates",
+ guest_id,
+ worktree.id()
+ );
+ (worktree.id(), worktree.snapshot())
+ })
+ .collect::<BTreeMap<_, _>>()
+ });
+
+ assert_eq!(
+ worktree_snapshots.keys().collect::<Vec<_>>(),
+ host_worktree_snapshots.keys().collect::<Vec<_>>(),
+ "guest {} has different worktrees than the host",
+ guest_id
+ );
+ for (id, host_snapshot) in &host_worktree_snapshots {
+ let guest_snapshot = &worktree_snapshots[id];
+ assert_eq!(
+ guest_snapshot.root_name(),
+ host_snapshot.root_name(),
+ "guest {} has different root name than the host for worktree {}",
+ guest_id,
+ id
+ );
+ assert_eq!(
+ guest_snapshot.entries(false).collect::<Vec<_>>(),
+ host_snapshot.entries(false).collect::<Vec<_>>(),
+ "guest {} has different snapshot than the host for worktree {}",
+ guest_id,
+ id
+ );
+ }
+
+ guest_client
+ .project
+ .as_ref()
+ .unwrap()
+ .read_with(guest_cx, |project, _| {
+ assert!(
+ !project.has_buffered_operations(),
+ "guest {} has buffered operations ",
+ guest_id,
+ );
+ });
+
+ for guest_buffer in &guest_client.buffers {
+ let buffer_id = guest_buffer.read_with(guest_cx, |buffer, _| buffer.remote_id());
+ let host_buffer = host_project.read_with(&host_cx, |project, _| {
+ project
+ .shared_buffer(guest_client.peer_id, buffer_id)
+ .expect(&format!(
+ "host doest not have buffer for guest:{}, peer:{}, id:{}",
+ guest_id, guest_client.peer_id, buffer_id
+ ))
+ });
+ assert_eq!(
+ guest_buffer.read_with(guest_cx, |buffer, _| buffer.text()),
+ host_buffer.read_with(&host_cx, |buffer, _| buffer.text()),
+ "guest {} buffer {} differs from the host's buffer",
+ guest_id,
+ buffer_id,
+ );
+ }
+ }
+ }
+
struct TestServer {
peer: Arc<Peer>,
app_state: Arc<AppState>,
server: Arc<Server>,
foreground: Rc<executor::Foreground>,
- notifications: mpsc::Receiver<()>,
+ notifications: mpsc::UnboundedReceiver<()>,
connection_killers: Arc<Mutex<HashMap<UserId, watch::Sender<Option<()>>>>>,
forbid_connections: Arc<AtomicBool>,
_test_db: TestDb,
}
impl TestServer {
- async fn start(foreground: Rc<executor::Foreground>) -> Self {
- let test_db = TestDb::new();
+ async fn start(
+ foreground: Rc<executor::Foreground>,
+ background: Arc<executor::Background>,
+ ) -> Self {
+ let test_db = TestDb::fake(background);
let app_state = Self::build_app_state(&test_db).await;
let peer = Peer::new();
- let notifications = mpsc::channel(128);
+ let notifications = mpsc::unbounded();
let server = Server::new(app_state.clone(), peer.clone(), Some(notifications.0));
Self {
peer,
@@ -43,6 +43,7 @@ pub struct ProjectShare {
pub struct WorktreeShare {
pub entries: HashMap<u64, proto::Entry>,
pub diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
+ pub next_update_id: u64,
}
#[derive(Default)]
@@ -403,6 +404,7 @@ impl Store {
connection_id: ConnectionId,
entries: HashMap<u64, proto::Entry>,
diagnostic_summaries: BTreeMap<PathBuf, proto::DiagnosticSummary>,
+ next_update_id: u64,
) -> tide::Result<SharedWorktree> {
let project = self
.projects
@@ -416,6 +418,7 @@ impl Store {
worktree.share = Some(WorktreeShare {
entries,
diagnostic_summaries,
+ next_update_id,
});
Ok(SharedWorktree {
authorized_user_ids: project.authorized_user_ids(),
@@ -534,6 +537,7 @@ impl Store {
connection_id: ConnectionId,
project_id: u64,
worktree_id: u64,
+ update_id: u64,
removed_entries: &[u64],
updated_entries: &[proto::Entry],
) -> tide::Result<Vec<ConnectionId>> {
@@ -545,6 +549,11 @@ impl Store {
.share
.as_mut()
.ok_or_else(|| anyhow!("worktree is not shared"))?;
+ if share.next_update_id != update_id {
+ return Err(anyhow!("received worktree updates out-of-order"))?;
+ }
+
+ share.next_update_id = update_id + 1;
for entry_id in removed_entries {
share.entries.remove(&entry_id);
}
@@ -478,6 +478,14 @@ impl<T: Item> SumTree<T> {
}
}
+impl<T: Item + PartialEq> PartialEq for SumTree<T> {
+ fn eq(&self, other: &Self) -> bool {
+ self.iter().eq(other.iter())
+ }
+}
+
+impl<T: Item + Eq> Eq for SumTree<T> {}
+
impl<T: KeyedItem> SumTree<T> {
pub fn insert_or_replace(&mut self, item: T, cx: &<T::Summary as Summary>::Context) -> bool {
let mut replaced = false;
@@ -21,7 +21,7 @@ use operation_queue::OperationQueue;
pub use patch::Patch;
pub use point::*;
pub use point_utf16::*;
-use postage::{oneshot, prelude::*};
+use postage::{barrier, oneshot, prelude::*};
#[cfg(any(test, feature = "test-support"))]
pub use random_char_iter::*;
use rope::TextDimension;
@@ -53,6 +53,7 @@ pub struct Buffer {
pub lamport_clock: clock::Lamport,
subscriptions: Topic,
edit_id_resolvers: HashMap<clock::Local, Vec<oneshot::Sender<()>>>,
+ version_barriers: Vec<(clock::Global, barrier::Sender)>,
}
#[derive(Clone, Debug)]
@@ -574,6 +575,7 @@ impl Buffer {
lamport_clock,
subscriptions: Default::default(),
edit_id_resolvers: Default::default(),
+ version_barriers: Default::default(),
}
}
@@ -835,6 +837,8 @@ impl Buffer {
}
}
}
+ self.version_barriers
+ .retain(|(version, _)| !self.snapshot.version().observed_all(version));
Ok(())
}
@@ -1305,6 +1309,16 @@ impl Buffer {
}
}
+ pub fn wait_for_version(&mut self, version: clock::Global) -> impl Future<Output = ()> {
+ let (tx, mut rx) = barrier::channel();
+ if !self.snapshot.version.observed_all(&version) {
+ self.version_barriers.push((version, tx));
+ }
+ async move {
+ rx.recv().await;
+ }
+ }
+
fn resolve_edit(&mut self, edit_id: clock::Local) {
for mut tx in self
.edit_id_resolvers
@@ -4,13 +4,14 @@ pub mod test;
use futures::Future;
use std::{
cmp::Ordering,
+ ops::AddAssign,
pin::Pin,
task::{Context, Poll},
};
-pub fn post_inc(value: &mut usize) -> usize {
+pub fn post_inc<T: From<u8> + AddAssign<T> + Copy>(value: &mut T) -> T {
let prev = *value;
- *value += 1;
+ *value += T::from(1);
prev
}
@@ -17,9 +17,9 @@ use gpui::{
json::{self, to_string_pretty, ToJson},
keymap::Binding,
platform::{CursorStyle, WindowOptions},
- AnyModelHandle, AnyViewHandle, AppContext, ClipboardItem, Entity, ModelContext, ModelHandle,
- MutableAppContext, PathPromptOptions, PromptLevel, RenderContext, Task, View, ViewContext,
- ViewHandle, WeakModelHandle, WeakViewHandle,
+ AnyModelHandle, AnyViewHandle, AppContext, ClipboardItem, Entity, ImageData, ModelContext,
+ ModelHandle, MutableAppContext, PathPromptOptions, PromptLevel, RenderContext, Task, View,
+ ViewContext, ViewHandle, WeakModelHandle, WeakViewHandle,
};
use language::LanguageRegistry;
use log::error;
@@ -1139,7 +1139,7 @@ impl Workspace {
Flex::row()
.with_children(self.render_share_icon(theme, cx))
.with_children(self.render_collaborators(theme, cx))
- .with_child(self.render_avatar(
+ .with_child(self.render_current_user(
self.user_store.read(cx).current_user().as_ref(),
self.project.read(cx).replica_id(),
theme,
@@ -1171,13 +1171,17 @@ impl Workspace {
collaborators.sort_unstable_by_key(|collaborator| collaborator.replica_id);
collaborators
.into_iter()
- .map(|collaborator| {
- self.render_avatar(Some(&collaborator.user), collaborator.replica_id, theme, cx)
+ .filter_map(|collaborator| {
+ Some(self.render_avatar(
+ collaborator.user.avatar.clone()?,
+ collaborator.replica_id,
+ theme,
+ ))
})
.collect()
}
- fn render_avatar(
+ fn render_current_user(
&self,
user: Option<&Arc<User>>,
replica_id: ReplicaId,
@@ -1185,33 +1189,9 @@ impl Workspace {
cx: &mut RenderContext<Self>,
) -> ElementBox {
if let Some(avatar) = user.and_then(|user| user.avatar.clone()) {
- ConstrainedBox::new(
- Stack::new()
- .with_child(
- ConstrainedBox::new(
- Image::new(avatar)
- .with_style(theme.workspace.titlebar.avatar)
- .boxed(),
- )
- .with_width(theme.workspace.titlebar.avatar_width)
- .aligned()
- .boxed(),
- )
- .with_child(
- AvatarRibbon::new(theme.editor.replica_selection_style(replica_id).cursor)
- .constrained()
- .with_width(theme.workspace.titlebar.avatar_ribbon.width)
- .with_height(theme.workspace.titlebar.avatar_ribbon.height)
- .aligned()
- .bottom()
- .boxed(),
- )
- .boxed(),
- )
- .with_width(theme.workspace.right_sidebar.width)
- .boxed()
+ self.render_avatar(avatar, replica_id, theme)
} else {
- MouseEventHandler::new::<Authenticate, _, _, _>(0, cx, |state, _| {
+ MouseEventHandler::new::<Authenticate, _, _, _>(cx.view_id(), cx, |state, _| {
let style = if state.hovered {
&theme.workspace.titlebar.hovered_sign_in_prompt
} else {
@@ -1229,6 +1209,39 @@ impl Workspace {
}
}
+ fn render_avatar(
+ &self,
+ avatar: Arc<ImageData>,
+ replica_id: ReplicaId,
+ theme: &Theme,
+ ) -> ElementBox {
+ ConstrainedBox::new(
+ Stack::new()
+ .with_child(
+ ConstrainedBox::new(
+ Image::new(avatar)
+ .with_style(theme.workspace.titlebar.avatar)
+ .boxed(),
+ )
+ .with_width(theme.workspace.titlebar.avatar_width)
+ .aligned()
+ .boxed(),
+ )
+ .with_child(
+ AvatarRibbon::new(theme.editor.replica_selection_style(replica_id).cursor)
+ .constrained()
+ .with_width(theme.workspace.titlebar.avatar_ribbon.width)
+ .with_height(theme.workspace.titlebar.avatar_ribbon.height)
+ .aligned()
+ .bottom()
+ .boxed(),
+ )
+ .boxed(),
+ )
+ .with_width(theme.workspace.right_sidebar.width)
+ .boxed()
+ }
+
fn render_share_icon(&self, theme: &Theme, cx: &mut RenderContext<Self>) -> Option<ElementBox> {
if self.project().read(cx).is_local() && self.client.user_id().is_some() {
enum Share {}
@@ -53,6 +53,8 @@ fn main() {
let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http.clone(), cx));
let mut path_openers = Vec::new();
+ project::Project::init(&client);
+ client::Channel::init(&client);
client::init(client.clone(), cx);
workspace::init(cx);
editor::init(cx, &mut path_openers);