From 310def292340933b5a462fcd99d9e77cfddb899d Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 12 Jan 2022 18:01:20 +0100 Subject: [PATCH 1/5] Implement `Buffer::format` Co-Authored-By: Nathan Sobo --- crates/client/src/client.rs | 8 ++ crates/editor/src/items.rs | 17 +++- crates/editor/src/multi_buffer.rs | 14 ++++ crates/language/src/buffer.rs | 67 ++++++++++++++- crates/language/src/language.rs | 18 ++++- crates/lsp/src/lsp.rs | 30 ++++--- crates/project/src/project.rs | 16 ++++ crates/project/src/worktree.rs | 87 +++++++++++++++----- crates/rpc/proto/zed.proto | 39 +++++---- crates/rpc/src/peer.rs | 8 +- crates/rpc/src/proto.rs | 3 + crates/server/src/rpc.rs | 130 ++++++++++++++++++++++++++++++ crates/text/src/rope.rs | 6 ++ 13 files changed, 387 insertions(+), 56 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index efda8097707033106ce15b86aa3b6cfe88a9ea1d..f8512b6550c0c45f2f9e5ef6baf2562e4703638c 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -691,6 +691,14 @@ impl Client { ) -> impl Future> { self.peer.respond(receipt, response) } + + pub fn respond_with_error( + &self, + receipt: Receipt, + error: proto::Error, + ) -> impl Future> { + self.peer.respond_with_error(receipt, error) + } } fn read_credentials_from_keychain(cx: &AsyncAppContext) -> Option { diff --git a/crates/editor/src/items.rs b/crates/editor/src/items.rs index b97f01ce69608944a229a07eadb496520d380e9d..8abcc76ef9fa20290e3919b1ee295e3daf666aed 100644 --- a/crates/editor/src/items.rs +++ b/crates/editor/src/items.rs @@ -1,4 +1,4 @@ -use crate::{Editor, Event}; +use crate::{Autoscroll, Editor, Event}; use crate::{MultiBuffer, ToPoint as _}; use anyhow::Result; use gpui::{ @@ -11,6 +11,7 @@ use project::{File, ProjectPath, Worktree}; use std::fmt::Write; use std::path::Path; use text::{Point, Selection}; +use util::TryFutureExt; use workspace::{ ItemHandle, ItemView, ItemViewHandle, PathOpener, Settings, StatusItemView, WeakItemHandle, Workspace, @@ -141,9 +142,17 @@ impl ItemView for Editor { } fn save(&mut self, cx: &mut ViewContext) -> Result>> { - let save = self.buffer().update(cx, |b, cx| b.save(cx))?; - Ok(cx.spawn(|_, _| async move { - save.await?; + let buffer = self.buffer().clone(); + Ok(cx.spawn(|editor, mut cx| async move { + buffer + .update(&mut cx, |buffer, cx| buffer.format(cx).log_err()) + .await; + editor.update(&mut cx, |editor, cx| { + editor.request_autoscroll(Autoscroll::Fit, cx) + }); + buffer + .update(&mut cx, |buffer, cx| buffer.save(cx))? + .await?; Ok(()) })) } diff --git a/crates/editor/src/multi_buffer.rs b/crates/editor/src/multi_buffer.rs index cd4a3207df053a1d27284b5993f3a3a495b36d9d..c7192cd622c51dbd43dbf58f5f561e045f64a2b9 100644 --- a/crates/editor/src/multi_buffer.rs +++ b/crates/editor/src/multi_buffer.rs @@ -798,6 +798,20 @@ impl MultiBuffer { cx.emit(event.clone()); } + pub fn format(&mut self, cx: &mut ModelContext) -> Task> { + let mut format_tasks = Vec::new(); + for BufferState { buffer, .. } in self.buffers.borrow().values() { + format_tasks.push(buffer.update(cx, |buffer, cx| buffer.format(cx))); + } + + cx.spawn(|_, _| async move { + for format in format_tasks { + format.await?; + } + Ok(()) + }) + } + pub fn save(&mut self, cx: &mut ModelContext) -> Result>> { let mut save_tasks = Vec::new(); for BufferState { buffer, .. } in self.buffers.borrow().values() { diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index c0c3d9c5f149894a3f8fdb16138c1362a1ca59da..e6b593f70d63f3c9411336a1faffca3e53f898bd 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -1,10 +1,13 @@ -use crate::diagnostic_set::{DiagnosticEntry, DiagnosticGroup}; pub use crate::{ diagnostic_set::DiagnosticSet, highlight_map::{HighlightId, HighlightMap}, proto, BracketPair, Grammar, Language, LanguageConfig, LanguageRegistry, LanguageServerConfig, PLAIN_TEXT, }; +use crate::{ + diagnostic_set::{DiagnosticEntry, DiagnosticGroup}, + range_from_lsp, ToPointUtf16, +}; use anyhow::{anyhow, Result}; use clock::ReplicaId; use futures::FutureExt as _; @@ -180,6 +183,9 @@ pub trait File { fn load_local(&self, cx: &AppContext) -> Option>>; + fn format_remote(&self, buffer_id: u64, cx: &mut MutableAppContext) + -> Option>>; + fn buffer_updated(&self, buffer_id: u64, operation: Operation, cx: &mut MutableAppContext); fn buffer_removed(&self, buffer_id: u64, cx: &mut MutableAppContext); @@ -437,6 +443,65 @@ impl Buffer { self.file.as_deref() } + pub fn format(&mut self, cx: &mut ModelContext) -> Task> { + let file = if let Some(file) = self.file.as_ref() { + file + } else { + return Task::ready(Err(anyhow!("buffer has no file"))); + }; + + if let Some(LanguageServerState { server, .. }) = self.language_server.as_ref() { + let server = server.clone(); + let abs_path = file.abs_path().unwrap(); + let version = self.version(); + cx.spawn(|this, mut cx| async move { + let edits = server + .request::(lsp::DocumentFormattingParams { + text_document: lsp::TextDocumentIdentifier::new( + lsp::Url::from_file_path(&abs_path).unwrap(), + ), + options: Default::default(), + work_done_progress_params: Default::default(), + }) + .await?; + + if let Some(edits) = edits { + this.update(&mut cx, |this, cx| { + if this.version == version { + for edit in &edits { + let range = range_from_lsp(edit.range); + if this.clip_point_utf16(range.start, Bias::Left) != range.start + || this.clip_point_utf16(range.end, Bias::Left) != range.end + { + return Err(anyhow!( + "invalid formatting edits received from language server" + )); + } + } + + for edit in edits.into_iter().rev() { + this.edit([range_from_lsp(edit.range)], edit.new_text, cx); + } + Ok(()) + } else { + Err(anyhow!("buffer edited since starting to format")) + } + }) + } else { + Ok(()) + } + }) + } else { + let format = file.format_remote(self.remote_id(), cx.as_mut()); + cx.spawn(|_, _| async move { + if let Some(format) = format { + format.await?; + } + Ok(()) + }) + } + } + pub fn save( &mut self, cx: &mut ModelContext, diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 9f7f9f75ac4d6b190210b940b2cec422308d6685..769bcbe69c03de41a4e61a417707a7c63dff9f62 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -15,7 +15,7 @@ use highlight_map::HighlightMap; use lazy_static::lazy_static; use parking_lot::Mutex; use serde::Deserialize; -use std::{path::Path, str, sync::Arc}; +use std::{ops::Range, path::Path, str, sync::Arc}; use theme::SyntaxTheme; use tree_sitter::{self, Query}; pub use tree_sitter::{Parser, Tree}; @@ -33,6 +33,10 @@ lazy_static! { )); } +pub trait ToPointUtf16 { + fn to_point_utf16(self) -> PointUtf16; +} + #[derive(Default, Deserialize)] pub struct LanguageConfig { pub name: String, @@ -244,3 +248,15 @@ impl LanguageServerConfig { ) } } + +impl ToPointUtf16 for lsp::Position { + fn to_point_utf16(self) -> PointUtf16 { + PointUtf16::new(self.line, self.character) + } +} + +pub fn range_from_lsp(range: lsp::Range) -> Range { + let start = PointUtf16::new(range.start.line, range.start.character); + let end = PointUtf16::new(range.end.line, range.end.character); + start..end +} diff --git a/crates/lsp/src/lsp.rs b/crates/lsp/src/lsp.rs index c3d264e8a99f227156378c07e25a4dca726204fa..6d975e8e9fa87fd06a6112ca5694937fcdb09bf5 100644 --- a/crates/lsp/src/lsp.rs +++ b/crates/lsp/src/lsp.rs @@ -494,17 +494,25 @@ impl FakeLanguageServer { } pub async fn receive_request(&mut self) -> (RequestId, T::Params) { - self.receive().await; - let request = serde_json::from_slice::>(&self.buffer).unwrap(); - assert_eq!(request.method, T::METHOD); - assert_eq!(request.jsonrpc, JSON_RPC_VERSION); - ( - RequestId { - id: request.id, - _type: std::marker::PhantomData, - }, - request.params, - ) + loop { + self.receive().await; + if let Ok(request) = serde_json::from_slice::>(&self.buffer) { + assert_eq!(request.method, T::METHOD); + assert_eq!(request.jsonrpc, JSON_RPC_VERSION); + return ( + RequestId { + id: request.id, + _type: std::marker::PhantomData, + }, + request.params, + ); + } else { + println!( + "skipping message in fake language server {:?}", + std::str::from_utf8(&self.buffer) + ); + } + } } pub async fn receive_notification(&mut self) -> T::Params { diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index 5f0c966f6f675a748cd6f936f4a7688e68f99ff2..af7a3d5939238c9afc1e47c6d6ddc6405dfb3b89 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -308,6 +308,7 @@ impl Project { client.subscribe_to_entity(remote_id, cx, Self::handle_update_buffer), client.subscribe_to_entity(remote_id, cx, Self::handle_save_buffer), client.subscribe_to_entity(remote_id, cx, Self::handle_buffer_saved), + client.subscribe_to_entity(remote_id, cx, Self::handle_format_buffer), ]); } } @@ -808,6 +809,21 @@ impl Project { Ok(()) } + pub fn handle_format_buffer( + &mut self, + envelope: TypedEnvelope, + rpc: Arc, + cx: &mut ModelContext, + ) -> Result<()> { + let worktree_id = WorktreeId::from_proto(envelope.payload.worktree_id); + if let Some(worktree) = self.worktree_for_id(worktree_id, cx) { + worktree.update(cx, |worktree, cx| { + worktree.handle_format_buffer(envelope, rpc, cx) + })?; + } + Ok(()) + } + pub fn handle_open_buffer( &mut self, envelope: TypedEnvelope, diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 7c16cf5f414238243a82dc5795b8444374261a2b..3d924b331072b238433186dee594e69c25c5b8a2 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -15,8 +15,8 @@ use gpui::{ Task, UpgradeModelHandle, WeakModelHandle, }; use language::{ - Buffer, Diagnostic, DiagnosticEntry, DiagnosticSeverity, File as _, Language, LanguageRegistry, - Operation, PointUtf16, Rope, + range_from_lsp, Buffer, Diagnostic, DiagnosticEntry, DiagnosticSeverity, File as _, Language, + LanguageRegistry, Operation, PointUtf16, Rope, }; use lazy_static::lazy_static; use lsp::LanguageServer; @@ -34,7 +34,7 @@ use std::{ ffi::{OsStr, OsString}, fmt, future::Future, - ops::{Deref, Range}, + ops::Deref, path::{Path, PathBuf}, sync::{ atomic::{AtomicUsize, Ordering::SeqCst}, @@ -580,6 +580,49 @@ impl Worktree { Ok(()) } + pub fn handle_format_buffer( + &mut self, + envelope: TypedEnvelope, + rpc: Arc, + cx: &mut ModelContext, + ) -> Result<()> { + let sender_id = envelope.original_sender_id()?; + let this = self.as_local().unwrap(); + let buffer = this + .shared_buffers + .get(&sender_id) + .and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned()) + .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?; + + let receipt = envelope.receipt(); + cx.spawn(|_, mut cx| async move { + let format = buffer.update(&mut cx, |buffer, cx| buffer.format(cx)).await; + // We spawn here in order to enqueue the sending of `Ack` *after* transmission of edits + // associated with formatting. + cx.spawn(|_| async move { + dbg!("responding"); + match format { + Ok(()) => rpc.respond(receipt, proto::Ack {}).await?, + Err(error) => { + rpc.respond_with_error( + receipt, + proto::Error { + message: error.to_string(), + }, + ) + .await? + } + } + Ok::<_, anyhow::Error>(()) + }) + .await + .log_err(); + }) + .detach(); + + Ok(()) + } + fn poll_snapshot(&mut self, cx: &mut ModelContext) { match self { Self::Local(worktree) => { @@ -880,6 +923,7 @@ impl Worktree { )), } { cx.spawn(|worktree, mut cx| async move { + dbg!(&operation); if let Err(error) = rpc .request(proto::UpdateBuffer { project_id, @@ -2259,6 +2303,27 @@ impl language::File for File { ) } + fn format_remote( + &self, + buffer_id: u64, + cx: &mut MutableAppContext, + ) -> Option>> { + let worktree = self.worktree.read(cx); + let worktree_id = worktree.id().to_proto(); + let worktree = worktree.as_remote()?; + let rpc = worktree.client.clone(); + let project_id = worktree.project_id; + Some(cx.foreground().spawn(async move { + rpc.request(proto::FormatBuffer { + project_id, + worktree_id, + buffer_id, + }) + .await?; + Ok(()) + })) + } + fn buffer_updated(&self, buffer_id: u64, operation: Operation, cx: &mut MutableAppContext) { self.worktree.update(cx, |worktree, cx| { worktree.send_buffer_update(buffer_id, operation, cx); @@ -3180,22 +3245,6 @@ impl<'a> TryFrom<(&'a CharBag, proto::Entry)> for Entry { } } -trait ToPointUtf16 { - fn to_point_utf16(self) -> PointUtf16; -} - -impl ToPointUtf16 for lsp::Position { - fn to_point_utf16(self) -> PointUtf16 { - PointUtf16::new(self.line, self.character) - } -} - -fn range_from_lsp(range: lsp::Range) -> Range { - let start = PointUtf16::new(range.start.line, range.start.character); - let end = PointUtf16::new(range.end.line, range.end.character); - start..end -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/rpc/proto/zed.proto b/crates/rpc/proto/zed.proto index f6300c44952a0935d3eb1f3d90ccc15062b8d9e9..47774bf360f92dcac041b561540c40d20483245b 100644 --- a/crates/rpc/proto/zed.proto +++ b/crates/rpc/proto/zed.proto @@ -35,22 +35,23 @@ message Envelope { UpdateBuffer update_buffer = 27; SaveBuffer save_buffer = 28; BufferSaved buffer_saved = 29; - - GetChannels get_channels = 30; - GetChannelsResponse get_channels_response = 31; - JoinChannel join_channel = 32; - JoinChannelResponse join_channel_response = 33; - LeaveChannel leave_channel = 34; - SendChannelMessage send_channel_message = 35; - SendChannelMessageResponse send_channel_message_response = 36; - ChannelMessageSent channel_message_sent = 37; - GetChannelMessages get_channel_messages = 38; - GetChannelMessagesResponse get_channel_messages_response = 39; - - UpdateContacts update_contacts = 40; - - GetUsers get_users = 41; - GetUsersResponse get_users_response = 42; + FormatBuffer format_buffer = 30; + + GetChannels get_channels = 31; + GetChannelsResponse get_channels_response = 32; + JoinChannel join_channel = 33; + JoinChannelResponse join_channel_response = 34; + LeaveChannel leave_channel = 35; + SendChannelMessage send_channel_message = 36; + SendChannelMessageResponse send_channel_message_response = 37; + ChannelMessageSent channel_message_sent = 38; + GetChannelMessages get_channel_messages = 39; + GetChannelMessagesResponse get_channel_messages_response = 40; + + UpdateContacts update_contacts = 41; + + GetUsers get_users = 42; + GetUsersResponse get_users_response = 43; } } @@ -168,6 +169,12 @@ message BufferSaved { Timestamp mtime = 5; } +message FormatBuffer { + uint64 project_id = 1; + uint64 worktree_id = 2; + uint64 buffer_id = 3; +} + message UpdateDiagnosticSummary { uint64 project_id = 1; uint64 worktree_id = 2; diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 091a0c1555dc9a6cca7019780ae95894d4488c9d..9b6d8c8786a21078ad76452775944b7dc15db457 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -398,7 +398,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 101, - visible_text: "path/one content".to_string(), + content: "path/one content".to_string(), ..Default::default() }), } @@ -419,7 +419,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 102, - visible_text: "path/two content".to_string(), + content: "path/two content".to_string(), ..Default::default() }), } @@ -448,7 +448,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 101, - visible_text: "path/one content".to_string(), + content: "path/one content".to_string(), ..Default::default() }), } @@ -458,7 +458,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 102, - visible_text: "path/two content".to_string(), + content: "path/two content".to_string(), ..Default::default() }), } diff --git a/crates/rpc/src/proto.rs b/crates/rpc/src/proto.rs index 91abc2523d3c026567b0a3c4f83fa00115ab3cdd..8860bc5f0549b0a4341e5fe85526c299a5f1fa24 100644 --- a/crates/rpc/src/proto.rs +++ b/crates/rpc/src/proto.rs @@ -128,6 +128,7 @@ messages!( DiskBasedDiagnosticsUpdated, DiskBasedDiagnosticsUpdating, Error, + FormatBuffer, GetChannelMessages, GetChannelMessagesResponse, GetChannels, @@ -162,6 +163,7 @@ messages!( ); request_messages!( + (FormatBuffer, Ack), (GetChannelMessages, GetChannelMessagesResponse), (GetChannels, GetChannelsResponse), (GetUsers, GetUsersResponse), @@ -185,6 +187,7 @@ entity_messages!( CloseBuffer, DiskBasedDiagnosticsUpdated, DiskBasedDiagnosticsUpdating, + FormatBuffer, JoinProject, LeaveProject, OpenBuffer, diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 76698c3a19171cb34a76b196db5dbb95ed805f7b..220c01ef1a67f3d79a7b31f7da962570349bbc12 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -79,6 +79,7 @@ impl Server { .add_handler(Server::update_buffer) .add_handler(Server::buffer_saved) .add_handler(Server::save_buffer) + .add_handler(Server::format_buffer) .add_handler(Server::get_channels) .add_handler(Server::get_users) .add_handler(Server::join_channel) @@ -660,6 +661,30 @@ impl Server { Ok(()) } + async fn format_buffer( + self: Arc, + request: TypedEnvelope, + ) -> tide::Result<()> { + let host; + { + let state = self.state(); + let project = state + .read_project(request.payload.project_id, request.sender_id) + .ok_or_else(|| anyhow!(NO_SUCH_PROJECT))?; + host = project.host_connection_id; + } + + let sender = request.sender_id; + let receipt = request.receipt(); + let response = self + .peer + .forward_request(sender, host, request.payload.clone()) + .await?; + self.peer.respond(receipt, response).await?; + + Ok(()) + } + async fn update_buffer( self: Arc, request: TypedEnvelope, @@ -2001,6 +2026,111 @@ mod tests { }); } + #[gpui::test(iterations = 1, seed = 2)] + async fn test_formatting_buffer(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { + cx_a.foreground().forbid_parking(); + let mut lang_registry = Arc::new(LanguageRegistry::new()); + let fs = Arc::new(FakeFs::new()); + + // Set up a fake language server. + let (language_server_config, mut fake_language_server) = + LanguageServerConfig::fake(cx_a.background()).await; + Arc::get_mut(&mut lang_registry) + .unwrap() + .add(Arc::new(Language::new( + LanguageConfig { + name: "Rust".to_string(), + path_suffixes: vec!["rs".to_string()], + language_server: Some(language_server_config), + ..Default::default() + }, + Some(tree_sitter_rust::language()), + ))); + + // Connect to a server as 2 clients. + let mut server = TestServer::start(cx_a.foreground()).await; + let client_a = server.create_client(&mut cx_a, "user_a").await; + let client_b = server.create_client(&mut cx_b, "user_b").await; + + // Share a project as client A + fs.insert_tree( + "/a", + json!({ + ".zed.toml": r#"collaborators = ["user_b"]"#, + "a.rs": "let one = two", + }), + ) + .await; + let project_a = cx_a.update(|cx| { + Project::local( + client_a.clone(), + client_a.user_store.clone(), + lang_registry.clone(), + fs.clone(), + cx, + ) + }); + let worktree_a = project_a + .update(&mut cx_a, |p, cx| p.add_local_worktree("/a", cx)) + .await + .unwrap(); + worktree_a + .read_with(&cx_a, |tree, _| tree.as_local().unwrap().scan_complete()) + .await; + let project_id = project_a + .update(&mut cx_a, |project, _| project.next_remote_id()) + .await; + project_a + .update(&mut cx_a, |project, cx| project.share(cx)) + .await + .unwrap(); + + // Join the worktree as client B. + let project_b = Project::remote( + project_id, + client_b.clone(), + client_b.user_store.clone(), + lang_registry.clone(), + fs.clone(), + &mut cx_b.to_async(), + ) + .await + .unwrap(); + + // Open the file to be formatted on client B. + let worktree_b = project_b.update(&mut cx_b, |p, _| p.worktrees()[0].clone()); + let buffer_b = cx_b + .background() + .spawn(worktree_b.update(&mut cx_b, |worktree, cx| worktree.open_buffer("a.rs", cx))) + .await + .unwrap(); + + let format = buffer_b.update(&mut cx_b, |buffer, cx| buffer.format(cx)); + let (request_id, _) = fake_language_server + .receive_request::() + .await; + fake_language_server + .respond( + request_id, + Some(vec![ + lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 4), lsp::Position::new(0, 4)), + new_text: "h".to_string(), + }, + lsp::TextEdit { + range: lsp::Range::new(lsp::Position::new(0, 7), lsp::Position::new(0, 7)), + new_text: "y".to_string(), + }, + ]), + ) + .await; + format.await.unwrap(); + assert_eq!( + buffer_b.read_with(&cx_b, |buffer, _| buffer.text()), + "let honey = two" + ); + } + #[gpui::test] async fn test_basic_chat(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); diff --git a/crates/text/src/rope.rs b/crates/text/src/rope.rs index 89ce278de1a65a2f53658b188cf6452c7d973960..d9c900d8bc40541128a619d8cd24219122e6b04b 100644 --- a/crates/text/src/rope.rs +++ b/crates/text/src/rope.rs @@ -593,6 +593,12 @@ impl Chunk { if ch == '\n' { point.row += 1; + if point.row > target.row { + panic!( + "point {:?} is beyond the end of a line with length {}", + target, point.column + ); + } point.column = 0; } else { point.column += ch.len_utf16() as u32; From 9e4b118214e2e72f271fcf5c6836435cfad3d54c Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 12 Jan 2022 18:02:41 +0100 Subject: [PATCH 2/5] Use synchronous locks for `Peer` state We hold these locks for a short amount of time anyway, and using an async lock could cause parallel sends to happen in an order different than the order in which `send`/`request` was called. Co-Authored-By: Nathan Sobo --- crates/client/src/client.rs | 10 +++---- crates/client/src/test.rs | 4 +-- crates/rpc/src/peer.rs | 60 ++++++++++++++++--------------------- crates/server/src/rpc.rs | 6 ++-- 4 files changed, 36 insertions(+), 44 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index f8512b6550c0c45f2f9e5ef6baf2562e4703638c..2a6cb1aefe825f9fa1cacdee6208973b27369948 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -661,9 +661,9 @@ impl Client { }) } - pub async fn disconnect(self: &Arc, cx: &AsyncAppContext) -> Result<()> { + pub fn disconnect(self: &Arc, cx: &AsyncAppContext) -> Result<()> { let conn_id = self.connection_id()?; - self.peer.disconnect(conn_id).await; + self.peer.disconnect(conn_id); self.set_status(Status::SignedOut, cx); Ok(()) } @@ -764,7 +764,7 @@ mod tests { let ping = server.receive::().await.unwrap(); server.respond(ping.receipt(), proto::Ack {}).await; - client.disconnect(&cx.to_async()).await.unwrap(); + client.disconnect(&cx.to_async()).unwrap(); assert!(server.receive::().await.is_err()); } @@ -783,7 +783,7 @@ mod tests { assert_eq!(server.auth_count(), 1); server.forbid_connections(); - server.disconnect().await; + server.disconnect(); while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} server.allow_connections(); @@ -792,7 +792,7 @@ mod tests { assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting server.forbid_connections(); - server.disconnect().await; + server.disconnect(); while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} // Clear cached credentials after authentication fails diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index a40d1ee3a776a93a91e513abc6f41bb78b0a0097..6339d025f1c4b87741250fae99599952748960d9 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -72,8 +72,8 @@ impl FakeServer { server } - pub async fn disconnect(&self) { - self.peer.disconnect(self.connection_id()).await; + pub fn disconnect(&self) { + self.peer.disconnect(self.connection_id()); self.connection_id.lock().take(); self.incoming.lock().take(); } diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 9b6d8c8786a21078ad76452775944b7dc15db457..848ae4440281ed3c4a013134b374da06c9778a19 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -1,8 +1,8 @@ use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; use super::Connection; use anyhow::{anyhow, Context, Result}; -use async_lock::{Mutex, RwLock}; use futures::FutureExt as _; +use parking_lot::{Mutex, RwLock}; use postage::{ mpsc, prelude::{Sink as _, Stream as _}, @@ -133,7 +133,7 @@ impl Peer { incoming = read_message => match incoming { Ok(incoming) => { if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().await.as_mut().unwrap().remove(&responding_to); + let channel = response_channels.lock().as_mut().unwrap().remove(&responding_to); if let Some(mut tx) = channel { tx.send(incoming).await.ok(); } else { @@ -169,25 +169,24 @@ impl Peer { } }; - response_channels.lock().await.take(); - this.connections.write().await.remove(&connection_id); + response_channels.lock().take(); + this.connections.write().remove(&connection_id); result }; self.connections .write() - .await .insert(connection_id, connection_state); (connection_id, handle_io, incoming_rx) } - pub async fn disconnect(&self, connection_id: ConnectionId) { - self.connections.write().await.remove(&connection_id); + pub fn disconnect(&self, connection_id: ConnectionId) { + self.connections.write().remove(&connection_id); } - pub async fn reset(&self) { - self.connections.write().await.clear(); + pub fn reset(&self) { + self.connections.write().clear(); } pub fn request( @@ -216,12 +215,11 @@ impl Peer { let this = self.clone(); let (tx, mut rx) = mpsc::channel(1); async move { - let mut connection = this.connection_state(receiver_id).await?; + let mut connection = this.connection_state(receiver_id)?; let message_id = connection.next_message_id.fetch_add(1, SeqCst); connection .response_channels .lock() - .await .as_mut() .ok_or_else(|| anyhow!("connection was closed"))? .insert(message_id, tx); @@ -250,7 +248,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection_state(receiver_id).await?; + let mut connection = this.connection_state(receiver_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -270,7 +268,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection_state(receiver_id).await?; + let mut connection = this.connection_state(receiver_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -289,7 +287,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection_state(receipt.sender_id).await?; + let mut connection = this.connection_state(receipt.sender_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -308,7 +306,7 @@ impl Peer { ) -> impl Future> { let this = self.clone(); async move { - let mut connection = this.connection_state(receipt.sender_id).await?; + let mut connection = this.connection_state(receipt.sender_id)?; let message_id = connection .next_message_id .fetch_add(1, atomic::Ordering::SeqCst); @@ -320,18 +318,12 @@ impl Peer { } } - fn connection_state( - self: &Arc, - connection_id: ConnectionId, - ) -> impl Future> { - let this = self.clone(); - async move { - let connections = this.connections.read().await; - let connection = connections - .get(&connection_id) - .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; - Ok(connection.clone()) - } + fn connection_state(&self, connection_id: ConnectionId) -> Result { + let connections = self.connections.read(); + let connection = connections + .get(&connection_id) + .ok_or_else(|| anyhow!("no such connection: {}", connection_id))?; + Ok(connection.clone()) } } @@ -398,7 +390,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 101, - content: "path/one content".to_string(), + visible_text: "path/one content".to_string(), ..Default::default() }), } @@ -419,14 +411,14 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 102, - content: "path/two content".to_string(), + visible_text: "path/two content".to_string(), ..Default::default() }), } ); - client1.disconnect(client1_conn_id).await; - client2.disconnect(client1_conn_id).await; + client1.disconnect(client1_conn_id); + client2.disconnect(client1_conn_id); async fn handle_messages( mut messages: mpsc::Receiver>, @@ -448,7 +440,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 101, - content: "path/one content".to_string(), + visible_text: "path/one content".to_string(), ..Default::default() }), } @@ -458,7 +450,7 @@ mod tests { proto::OpenBufferResponse { buffer: Some(proto::Buffer { id: 102, - content: "path/two content".to_string(), + visible_text: "path/two content".to_string(), ..Default::default() }), } @@ -502,7 +494,7 @@ mod tests { }) .detach(); - client.disconnect(connection_id).await; + client.disconnect(connection_id); io_ended_rx.recv().await; messages_ended_rx.recv().await; diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 220c01ef1a67f3d79a7b31f7da962570349bbc12..8248dfa103df5cf316aa690571756b1a1f341592 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -174,7 +174,7 @@ impl Server { } async fn sign_out(self: &mut Arc, connection_id: ConnectionId) -> tide::Result<()> { - self.peer.disconnect(connection_id).await; + self.peer.disconnect(connection_id); let removed_connection = self.state_mut().remove_connection(connection_id)?; for (project_id, project) in removed_connection.hosted_projects { @@ -1801,7 +1801,7 @@ mod tests { .await; // Drop client B's connection and ensure client A observes client B leaving the worktree. - client_b.disconnect(&cx_b.to_async()).await.unwrap(); + client_b.disconnect(&cx_b.to_async()).unwrap(); project_a .condition(&cx_a, |p, _| p.collaborators().len() == 0) .await; @@ -2833,7 +2833,7 @@ mod tests { impl Drop for TestServer { fn drop(&mut self) { - task::block_on(self.peer.reset()); + self.peer.reset(); } } From 8b53868f8aff8a92a8eccad37ce32c23613e85c5 Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 12 Jan 2022 18:26:00 +0100 Subject: [PATCH 3/5] Preserve the order of responses with respect to all other incoming messages Co-Authored-By: Nathan Sobo Co-Authored-By: Max Brunsfeld --- crates/client/src/client.rs | 23 +++++++------- crates/client/src/test.rs | 7 ++--- crates/project/src/worktree.rs | 2 -- crates/rpc/src/peer.rs | 55 +++++++++++++++++++++------------- crates/server/src/rpc.rs | 16 +++++----- 5 files changed, 57 insertions(+), 46 deletions(-) diff --git a/crates/client/src/client.rs b/crates/client/src/client.rs index 2a6cb1aefe825f9fa1cacdee6208973b27369948..e22cd7cba90feeb137fe3856770851c253084f1f 100644 --- a/crates/client/src/client.rs +++ b/crates/client/src/client.rs @@ -11,11 +11,12 @@ use async_tungstenite::tungstenite::{ error::Error as WebsocketError, http::{Request, StatusCode}, }; +use futures::StreamExt; use gpui::{action, AsyncAppContext, Entity, ModelContext, MutableAppContext, Task}; use http::HttpClient; use lazy_static::lazy_static; use parking_lot::RwLock; -use postage::{prelude::Stream, watch}; +use postage::watch; use rand::prelude::*; use rpc::proto::{AnyTypedEnvelope, EntityMessage, EnvelopedMessage, RequestMessage}; use std::{ @@ -436,7 +437,7 @@ impl Client { let mut cx = cx.clone(); let this = self.clone(); async move { - while let Some(message) = incoming.recv().await { + while let Some(message) = incoming.next().await { let mut state = this.state.write(); let payload_type_id = message.payload_type_id(); let entity_id = if let Some(extract_entity_id) = @@ -777,23 +778,23 @@ mod tests { let server = FakeServer::for_client(user_id, &mut client, &cx).await; let mut status = client.status(); assert!(matches!( - status.recv().await, + status.next().await, Some(Status::Connected { .. }) )); assert_eq!(server.auth_count(), 1); server.forbid_connections(); server.disconnect(); - while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} server.allow_connections(); cx.foreground().advance_clock(Duration::from_secs(10)); - while !matches!(status.recv().await, Some(Status::Connected { .. })) {} + while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 1); // Client reused the cached credentials when reconnecting server.forbid_connections(); server.disconnect(); - while !matches!(status.recv().await, Some(Status::ReconnectionError { .. })) {} + while !matches!(status.next().await, Some(Status::ReconnectionError { .. })) {} // Clear cached credentials after authentication fails server.roll_access_token(); @@ -801,7 +802,7 @@ mod tests { cx.foreground().advance_clock(Duration::from_secs(10)); assert_eq!(server.auth_count(), 1); cx.foreground().advance_clock(Duration::from_secs(10)); - while !matches!(status.recv().await, Some(Status::Connected { .. })) {} + while !matches!(status.next().await, Some(Status::Connected { .. })) {} assert_eq!(server.auth_count(), 2); // Client re-authenticated due to an invalid token } @@ -861,8 +862,8 @@ mod tests { server.send(proto::UnshareProject { project_id: 1 }).await; server.send(proto::UnshareProject { project_id: 2 }).await; - done_rx1.recv().await.unwrap(); - done_rx2.recv().await.unwrap(); + done_rx1.next().await.unwrap(); + done_rx2.next().await.unwrap(); } #[gpui::test] @@ -890,7 +891,7 @@ mod tests { }) }); server.send(proto::Ping {}).await; - done_rx2.recv().await.unwrap(); + done_rx2.next().await.unwrap(); } #[gpui::test] @@ -914,7 +915,7 @@ mod tests { )); }); server.send(proto::Ping {}).await; - done_rx.recv().await.unwrap(); + done_rx.next().await.unwrap(); } struct Model { diff --git a/crates/client/src/test.rs b/crates/client/src/test.rs index 6339d025f1c4b87741250fae99599952748960d9..1630a454b79296e27ec9eb1545aeb8f438b010e6 100644 --- a/crates/client/src/test.rs +++ b/crates/client/src/test.rs @@ -1,10 +1,9 @@ use super::Client; use super::*; use crate::http::{HttpClient, Request, Response, ServerResponse}; -use futures::{future::BoxFuture, Future}; +use futures::{future::BoxFuture, stream::BoxStream, Future, StreamExt}; use gpui::{ModelHandle, TestAppContext}; use parking_lot::Mutex; -use postage::{mpsc, prelude::Stream}; use rpc::{proto, ConnectionId, Peer, Receipt, TypedEnvelope}; use std::fmt; use std::sync::atomic::Ordering::SeqCst; @@ -15,7 +14,7 @@ use std::sync::{ pub struct FakeServer { peer: Arc, - incoming: Mutex>>>, + incoming: Mutex>>>, connection_id: Mutex>, forbid_connections: AtomicBool, auth_count: AtomicUsize, @@ -129,7 +128,7 @@ impl FakeServer { .lock() .as_mut() .expect("not connected") - .recv() + .next() .await .ok_or_else(|| anyhow!("other half hung up"))?; let type_name = message.payload_type_name(); diff --git a/crates/project/src/worktree.rs b/crates/project/src/worktree.rs index 3d924b331072b238433186dee594e69c25c5b8a2..a9ee86268d90e84132b44f85574b58a3d506783f 100644 --- a/crates/project/src/worktree.rs +++ b/crates/project/src/worktree.rs @@ -600,7 +600,6 @@ impl Worktree { // We spawn here in order to enqueue the sending of `Ack` *after* transmission of edits // associated with formatting. cx.spawn(|_| async move { - dbg!("responding"); match format { Ok(()) => rpc.respond(receipt, proto::Ack {}).await?, Err(error) => { @@ -923,7 +922,6 @@ impl Worktree { )), } { cx.spawn(|worktree, mut cx| async move { - dbg!(&operation); if let Err(error) = rpc .request(proto::UpdateBuffer { project_id, diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 848ae4440281ed3c4a013134b374da06c9778a19..2f1ac2a249a5bb4727dc080c1ef8f1f3f6a57a13 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -1,7 +1,8 @@ use super::proto::{self, AnyTypedEnvelope, EnvelopedMessage, MessageStream, RequestMessage}; use super::Connection; use anyhow::{anyhow, Context, Result}; -use futures::FutureExt as _; +use futures::stream::BoxStream; +use futures::{FutureExt as _, StreamExt}; use parking_lot::{Mutex, RwLock}; use postage::{ mpsc, @@ -109,7 +110,7 @@ impl Peer { ) -> ( ConnectionId, impl Future> + Send, - mpsc::Receiver>, + BoxStream<'static, Box>, ) { let connection_id = ConnectionId(self.next_connection_id.fetch_add(1, SeqCst)); let (mut incoming_tx, incoming_rx) = mpsc::channel(64); @@ -132,23 +133,9 @@ impl Peer { futures::select_biased! { incoming = read_message => match incoming { Ok(incoming) => { - if let Some(responding_to) = incoming.responding_to { - let channel = response_channels.lock().as_mut().unwrap().remove(&responding_to); - if let Some(mut tx) = channel { - tx.send(incoming).await.ok(); - } else { - log::warn!("received RPC response to unknown request {}", responding_to); - } - } else { - if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { - if incoming_tx.send(envelope).await.is_err() { - break 'outer Ok(()) - } - } else { - log::error!("unable to construct a typed envelope"); - } + if incoming_tx.send(incoming).await.is_err() { + break 'outer Ok(()); } - break; } Err(error) => { @@ -174,11 +161,38 @@ impl Peer { result }; + let response_channels = connection_state.response_channels.clone(); self.connections .write() .insert(connection_id, connection_state); - (connection_id, handle_io, incoming_rx) + let incoming_rx = incoming_rx.filter_map(move |incoming| { + let response_channels = response_channels.clone(); + async move { + if let Some(responding_to) = incoming.responding_to { + let channel = response_channels + .lock() + .as_mut() + .unwrap() + .remove(&responding_to); + if let Some(mut tx) = channel { + tx.send(incoming).await.ok(); + } else { + log::warn!("received RPC response to unknown request {}", responding_to); + } + + None + } else { + if let Some(envelope) = proto::build_typed_envelope(connection_id, incoming) { + Some(envelope) + } else { + log::error!("unable to construct a typed envelope"); + None + } + } + } + }); + (connection_id, handle_io, incoming_rx.boxed()) } pub fn disconnect(&self, connection_id: ConnectionId) { @@ -332,7 +346,6 @@ mod tests { use super::*; use crate::TypedEnvelope; use async_tungstenite::tungstenite::Message as WebSocketMessage; - use futures::StreamExt as _; #[test] fn test_request_response() { @@ -421,7 +434,7 @@ mod tests { client2.disconnect(client1_conn_id); async fn handle_messages( - mut messages: mpsc::Receiver>, + mut messages: BoxStream<'static, Box>, peer: Arc, ) -> Result<()> { while let Some(envelope) = messages.next().await { diff --git a/crates/server/src/rpc.rs b/crates/server/src/rpc.rs index 8248dfa103df5cf316aa690571756b1a1f341592..40a3f956bbc4bd69decd71a8e8e2955e359646ec 100644 --- a/crates/server/src/rpc.rs +++ b/crates/server/src/rpc.rs @@ -9,9 +9,9 @@ use anyhow::anyhow; use async_std::task; use async_tungstenite::{tungstenite::protocol::Role, WebSocketStream}; use collections::{HashMap, HashSet}; -use futures::{future::BoxFuture, FutureExt}; +use futures::{future::BoxFuture, FutureExt, StreamExt}; use parking_lot::{RwLock, RwLockReadGuard, RwLockWriteGuard}; -use postage::{mpsc, prelude::Sink as _, prelude::Stream as _}; +use postage::{mpsc, prelude::Sink as _}; use rpc::{ proto::{self, AnyTypedEnvelope, EnvelopedMessage}, Connection, ConnectionId, Peer, TypedEnvelope, @@ -133,7 +133,7 @@ impl Server { let handle_io = handle_io.fuse(); futures::pin_mut!(handle_io); loop { - let next_message = incoming_rx.recv().fuse(); + let next_message = incoming_rx.next().fuse(); futures::pin_mut!(next_message); futures::select_biased! { message = next_message => { @@ -2026,7 +2026,7 @@ mod tests { }); } - #[gpui::test(iterations = 1, seed = 2)] + #[gpui::test] async fn test_formatting_buffer(mut cx_a: TestAppContext, mut cx_b: TestAppContext) { cx_a.foreground().forbid_parking(); let mut lang_registry = Arc::new(LanguageRegistry::new()); @@ -2425,7 +2425,7 @@ mod tests { server.forbid_connections(); server.disconnect_client(client_b.current_user_id(&cx_b)); while !matches!( - status_b.recv().await, + status_b.next().await, Some(client::Status::ReconnectionError { .. }) ) {} @@ -2769,11 +2769,11 @@ mod tests { .await .unwrap(); - let peer_id = PeerId(connection_id_rx.recv().await.unwrap().0); + let peer_id = PeerId(connection_id_rx.next().await.unwrap().0); let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http, cx)); let mut authed_user = user_store.read_with(cx, |user_store, _| user_store.watch_current_user()); - while authed_user.recv().await.unwrap().is_none() {} + while authed_user.next().await.unwrap().is_none() {} TestClient { client, @@ -2822,7 +2822,7 @@ mod tests { async_std::future::timeout(Duration::from_millis(500), async { while !(predicate)(&*self.server.store.read()) { self.foreground.start_waiting(); - self.notifications.recv().await; + self.notifications.next().await; self.foreground.finish_waiting(); } }) From 66694b4c9a19e9630afbaa7c75122efdfe826f2e Mon Sep 17 00:00:00 2001 From: Antonio Scandurra Date: Wed, 12 Jan 2022 18:43:23 +0100 Subject: [PATCH 4/5] Fix failing tests Co-Authored-By: Max Brunsfeld --- crates/language/src/buffer.rs | 2 +- crates/rpc/src/peer.rs | 22 +++++++++++----------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index e6b593f70d63f3c9411336a1faffca3e53f898bd..07d4017c09eb14f64a7b9463b29247f436154403 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -6,7 +6,7 @@ pub use crate::{ }; use crate::{ diagnostic_set::{DiagnosticEntry, DiagnosticGroup}, - range_from_lsp, ToPointUtf16, + range_from_lsp, }; use anyhow::{anyhow, Result}; use clock::ReplicaId; diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 2f1ac2a249a5bb4727dc080c1ef8f1f3f6a57a13..30d754e97de831e22514b60b89e5b473422628ee 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -170,11 +170,7 @@ impl Peer { let response_channels = response_channels.clone(); async move { if let Some(responding_to) = incoming.responding_to { - let channel = response_channels - .lock() - .as_mut() - .unwrap() - .remove(&responding_to); + let channel = response_channels.lock().as_mut()?.remove(&responding_to); if let Some(mut tx) = channel { tx.send(incoming).await.ok(); } else { @@ -356,21 +352,25 @@ mod tests { let client2 = Peer::new(); let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory(); - let (client1_conn_id, io_task1, _) = + let (client1_conn_id, io_task1, client1_incoming) = client1.add_connection(client1_to_server_conn).await; - let (_, io_task2, incoming1) = server.add_connection(server_to_client_1_conn).await; + let (_, io_task2, server_incoming1) = + server.add_connection(server_to_client_1_conn).await; let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory(); - let (client2_conn_id, io_task3, _) = + let (client2_conn_id, io_task3, client2_incoming) = client2.add_connection(client2_to_server_conn).await; - let (_, io_task4, incoming2) = server.add_connection(server_to_client_2_conn).await; + let (_, io_task4, server_incoming2) = + server.add_connection(server_to_client_2_conn).await; smol::spawn(io_task1).detach(); smol::spawn(io_task2).detach(); smol::spawn(io_task3).detach(); smol::spawn(io_task4).detach(); - smol::spawn(handle_messages(incoming1, server.clone())).detach(); - smol::spawn(handle_messages(incoming2, server.clone())).detach(); + smol::spawn(handle_messages(server_incoming1, server.clone())).detach(); + smol::spawn(handle_messages(client1_incoming, client1.clone())).detach(); + smol::spawn(handle_messages(server_incoming2, server.clone())).detach(); + smol::spawn(handle_messages(client2_incoming, client2.clone())).detach(); assert_eq!( client1 From 30225678c0e431360ba8c4342adf7c4c0be72dc0 Mon Sep 17 00:00:00 2001 From: Nathan Sobo Date: Wed, 12 Jan 2022 11:19:17 -0700 Subject: [PATCH 5/5] Test ordering of responses with respect to uni-directional messages Co-Authored-By: Max Brunsfeld Co-Authored-By: Antonio Scandurra --- Cargo.lock | 1 + crates/rpc/Cargo.toml | 1 + crates/rpc/src/peer.rs | 442 +++++++++++++++++++++++++---------------- 3 files changed, 278 insertions(+), 166 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index dda7116de2ffb6dbf8351ccb3ec0752c9e9f3838..8c3174d68d40df5c364bee8327318476b8dd0fa7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3837,6 +3837,7 @@ dependencies = [ "async-tungstenite", "base64 0.13.0", "futures", + "gpui", "log", "parking_lot", "postage", diff --git a/crates/rpc/Cargo.toml b/crates/rpc/Cargo.toml index f16d7f39c2a1f4f82beba4b6a334402d781d61e9..4be612eec77ae902db19b04cd04dcd3b19adf527 100644 --- a/crates/rpc/Cargo.toml +++ b/crates/rpc/Cargo.toml @@ -30,5 +30,6 @@ zstd = "0.9" prost-build = "0.8" [dev-dependencies] +gpui = { path = "../gpui", features = ["test-support"] } smol = "1.2.5" tempdir = "0.3.7" diff --git a/crates/rpc/src/peer.rs b/crates/rpc/src/peer.rs index 30d754e97de831e22514b60b89e5b473422628ee..ce9680173311ecb42dfef999c6fef7dee09e606f 100644 --- a/crates/rpc/src/peer.rs +++ b/crates/rpc/src/peer.rs @@ -342,201 +342,311 @@ mod tests { use super::*; use crate::TypedEnvelope; use async_tungstenite::tungstenite::Message as WebSocketMessage; + use gpui::TestAppContext; + + #[gpui::test(iterations = 10)] + async fn test_request_response(cx: TestAppContext) { + let executor = cx.foreground(); + + // create 2 clients connected to 1 server + let server = Peer::new(); + let client1 = Peer::new(); + let client2 = Peer::new(); + + let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory(); + let (client1_conn_id, io_task1, client1_incoming) = + client1.add_connection(client1_to_server_conn).await; + let (_, io_task2, server_incoming1) = server.add_connection(server_to_client_1_conn).await; + + let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory(); + let (client2_conn_id, io_task3, client2_incoming) = + client2.add_connection(client2_to_server_conn).await; + let (_, io_task4, server_incoming2) = server.add_connection(server_to_client_2_conn).await; + + executor.spawn(io_task1).detach(); + executor.spawn(io_task2).detach(); + executor.spawn(io_task3).detach(); + executor.spawn(io_task4).detach(); + executor + .spawn(handle_messages(server_incoming1, server.clone())) + .detach(); + executor + .spawn(handle_messages(client1_incoming, client1.clone())) + .detach(); + executor + .spawn(handle_messages(server_incoming2, server.clone())) + .detach(); + executor + .spawn(handle_messages(client2_incoming, client2.clone())) + .detach(); - #[test] - fn test_request_response() { - smol::block_on(async move { - // create 2 clients connected to 1 server - let server = Peer::new(); - let client1 = Peer::new(); - let client2 = Peer::new(); - - let (client1_to_server_conn, server_to_client_1_conn, _) = Connection::in_memory(); - let (client1_conn_id, io_task1, client1_incoming) = - client1.add_connection(client1_to_server_conn).await; - let (_, io_task2, server_incoming1) = - server.add_connection(server_to_client_1_conn).await; - - let (client2_to_server_conn, server_to_client_2_conn, _) = Connection::in_memory(); - let (client2_conn_id, io_task3, client2_incoming) = - client2.add_connection(client2_to_server_conn).await; - let (_, io_task4, server_incoming2) = - server.add_connection(server_to_client_2_conn).await; - - smol::spawn(io_task1).detach(); - smol::spawn(io_task2).detach(); - smol::spawn(io_task3).detach(); - smol::spawn(io_task4).detach(); - smol::spawn(handle_messages(server_incoming1, server.clone())).detach(); - smol::spawn(handle_messages(client1_incoming, client1.clone())).detach(); - smol::spawn(handle_messages(server_incoming2, server.clone())).detach(); - smol::spawn(handle_messages(client2_incoming, client2.clone())).detach(); - - assert_eq!( - client1 - .request(client1_conn_id, proto::Ping {},) - .await - .unwrap(), - proto::Ack {} - ); + assert_eq!( + client1 + .request(client1_conn_id, proto::Ping {},) + .await + .unwrap(), + proto::Ack {} + ); - assert_eq!( - client2 - .request(client2_conn_id, proto::Ping {},) + assert_eq!( + client2 + .request(client2_conn_id, proto::Ping {},) + .await + .unwrap(), + proto::Ack {} + ); + + assert_eq!( + client1 + .request( + client1_conn_id, + proto::OpenBuffer { + project_id: 0, + worktree_id: 1, + path: "path/one".to_string(), + }, + ) + .await + .unwrap(), + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + visible_text: "path/one content".to_string(), + ..Default::default() + }), + } + ); + + assert_eq!( + client2 + .request( + client2_conn_id, + proto::OpenBuffer { + project_id: 0, + worktree_id: 2, + path: "path/two".to_string(), + }, + ) + .await + .unwrap(), + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + visible_text: "path/two content".to_string(), + ..Default::default() + }), + } + ); + + client1.disconnect(client1_conn_id); + client2.disconnect(client1_conn_id); + + async fn handle_messages( + mut messages: BoxStream<'static, Box>, + peer: Arc, + ) -> Result<()> { + while let Some(envelope) = messages.next().await { + let envelope = envelope.into_any(); + if let Some(envelope) = envelope.downcast_ref::>() { + let receipt = envelope.receipt(); + peer.respond(receipt, proto::Ack {}).await? + } else if let Some(envelope) = + envelope.downcast_ref::>() + { + let message = &envelope.payload; + let receipt = envelope.receipt(); + let response = match message.path.as_str() { + "path/one" => { + assert_eq!(message.worktree_id, 1); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 101, + visible_text: "path/one content".to_string(), + ..Default::default() + }), + } + } + "path/two" => { + assert_eq!(message.worktree_id, 2); + proto::OpenBufferResponse { + buffer: Some(proto::Buffer { + id: 102, + visible_text: "path/two content".to_string(), + ..Default::default() + }), + } + } + _ => { + panic!("unexpected path {}", message.path); + } + }; + + peer.respond(receipt, response).await? + } else { + panic!("unknown message type"); + } + } + + Ok(()) + } + } + + #[gpui::test(iterations = 10)] + async fn test_order_of_response_and_incoming(cx: TestAppContext) { + let executor = cx.foreground(); + let server = Peer::new(); + let client = Peer::new(); + + let (client_to_server_conn, server_to_client_conn, _) = Connection::in_memory(); + let (client_to_server_conn_id, io_task1, mut client_incoming) = + client.add_connection(client_to_server_conn).await; + let (server_to_client_conn_id, io_task2, mut server_incoming) = + server.add_connection(server_to_client_conn).await; + + executor.spawn(io_task1).detach(); + executor.spawn(io_task2).detach(); + + executor + .spawn(async move { + let request = server_incoming + .next() .await - .unwrap(), - proto::Ack {} - ); - - assert_eq!( - client1 - .request( - client1_conn_id, - proto::OpenBuffer { - project_id: 0, - worktree_id: 1, - path: "path/one".to_string(), + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + + server + .send( + server_to_client_conn_id, + proto::Error { + message: "message 1".to_string(), }, ) .await - .unwrap(), - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 101, - visible_text: "path/one content".to_string(), - ..Default::default() - }), - } - ); - - assert_eq!( - client2 - .request( - client2_conn_id, - proto::OpenBuffer { - project_id: 0, - worktree_id: 2, - path: "path/two".to_string(), + .unwrap(); + server + .send( + server_to_client_conn_id, + proto::Error { + message: "message 2".to_string(), }, ) .await - .unwrap(), - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 102, - visible_text: "path/two content".to_string(), - ..Default::default() - }), - } - ); - - client1.disconnect(client1_conn_id); - client2.disconnect(client1_conn_id); - - async fn handle_messages( - mut messages: BoxStream<'static, Box>, - peer: Arc, - ) -> Result<()> { - while let Some(envelope) = messages.next().await { - let envelope = envelope.into_any(); - if let Some(envelope) = envelope.downcast_ref::>() { - let receipt = envelope.receipt(); - peer.respond(receipt, proto::Ack {}).await? - } else if let Some(envelope) = - envelope.downcast_ref::>() - { - let message = &envelope.payload; - let receipt = envelope.receipt(); - let response = match message.path.as_str() { - "path/one" => { - assert_eq!(message.worktree_id, 1); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 101, - visible_text: "path/one content".to_string(), - ..Default::default() - }), - } - } - "path/two" => { - assert_eq!(message.worktree_id, 2); - proto::OpenBufferResponse { - buffer: Some(proto::Buffer { - id: 102, - visible_text: "path/two content".to_string(), - ..Default::default() - }), - } - } - _ => { - panic!("unexpected path {}", message.path); - } - }; + .unwrap(); + server + .respond(request.receipt(), proto::Ack {}) + .await + .unwrap(); - peer.respond(receipt, response).await? - } else { - panic!("unknown message type"); - } - } + // Prevent the connection from being dropped + server_incoming.next().await; + }) + .detach(); + + let events = Arc::new(Mutex::new(Vec::new())); - Ok(()) + let response = client.request(client_to_server_conn_id, proto::Ping {}); + let response_task = executor.spawn({ + let events = events.clone(); + async move { + response.await.unwrap(); + events.lock().push("response".to_string()); } }); + + executor + .spawn({ + let events = events.clone(); + async move { + let incoming1 = client_incoming + .next() + .await + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + events.lock().push(incoming1.payload.message); + let incoming2 = client_incoming + .next() + .await + .unwrap() + .into_any() + .downcast::>() + .unwrap(); + events.lock().push(incoming2.payload.message); + + // Prevent the connection from being dropped + client_incoming.next().await; + } + }) + .detach(); + + response_task.await; + assert_eq!( + &*events.lock(), + &[ + "message 1".to_string(), + "message 2".to_string(), + "response".to_string() + ] + ); } - #[test] - fn test_disconnect() { - smol::block_on(async move { - let (client_conn, mut server_conn, _) = Connection::in_memory(); + #[gpui::test(iterations = 10)] + async fn test_disconnect(cx: TestAppContext) { + let executor = cx.foreground(); + + let (client_conn, mut server_conn, _) = Connection::in_memory(); - let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = - client.add_connection(client_conn).await; + let client = Peer::new(); + let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await; - let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); - smol::spawn(async move { + let (mut io_ended_tx, mut io_ended_rx) = postage::barrier::channel(); + executor + .spawn(async move { io_handler.await.ok(); io_ended_tx.send(()).await.unwrap(); }) .detach(); - let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel(); - smol::spawn(async move { + let (mut messages_ended_tx, mut messages_ended_rx) = postage::barrier::channel(); + executor + .spawn(async move { incoming.next().await; messages_ended_tx.send(()).await.unwrap(); }) .detach(); - client.disconnect(connection_id); + client.disconnect(connection_id); - io_ended_rx.recv().await; - messages_ended_rx.recv().await; - assert!(server_conn - .send(WebSocketMessage::Binary(vec![])) - .await - .is_err()); - }); + io_ended_rx.recv().await; + messages_ended_rx.recv().await; + assert!(server_conn + .send(WebSocketMessage::Binary(vec![])) + .await + .is_err()); } - #[test] - fn test_io_error() { - smol::block_on(async move { - let (client_conn, mut server_conn, _) = Connection::in_memory(); - - let client = Peer::new(); - let (connection_id, io_handler, mut incoming) = - client.add_connection(client_conn).await; - smol::spawn(io_handler).detach(); - smol::spawn(async move { incoming.next().await }).detach(); - - let response = smol::spawn(client.request(connection_id, proto::Ping {})); - let _request = server_conn.rx.next().await.unwrap().unwrap(); - - drop(server_conn); - assert_eq!( - response.await.unwrap_err().to_string(), - "connection was closed" - ); - }); + #[gpui::test(iterations = 10)] + async fn test_io_error(cx: TestAppContext) { + let executor = cx.foreground(); + let (client_conn, mut server_conn, _) = Connection::in_memory(); + + let client = Peer::new(); + let (connection_id, io_handler, mut incoming) = client.add_connection(client_conn).await; + executor.spawn(io_handler).detach(); + executor + .spawn(async move { incoming.next().await }) + .detach(); + + let response = executor.spawn(client.request(connection_id, proto::Ping {})); + let _request = server_conn.rx.next().await.unwrap().unwrap(); + + drop(server_conn); + assert_eq!( + response.await.unwrap_err().to_string(), + "connection was closed" + ); } }