Allow FakeLanguageServer handlers to handle multiple requests

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/editor/src/editor.rs |  10 
crates/lsp/src/lsp.rs       |  88 +++++++--------
crates/server/src/rpc.rs    | 219 ++++++++++++++++++++++----------------
3 files changed, 171 insertions(+), 146 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -5447,8 +5447,8 @@ mod tests {
     use super::*;
     use language::LanguageConfig;
     use lsp::FakeLanguageServer;
-    use postage::prelude::Stream;
     use project::{FakeFs, ProjectPath};
+    use smol::stream::StreamExt;
     use std::{cell::RefCell, rc::Rc, time::Instant};
     use text::Point;
     use unindent::Unindent;
@@ -7911,7 +7911,7 @@ mod tests {
                 );
                 Some(lsp::CompletionResponse::Array(
                     completions
-                        .into_iter()
+                        .iter()
                         .map(|(range, new_text)| lsp::CompletionItem {
                             label: new_text.to_string(),
                             text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
@@ -7926,7 +7926,7 @@ mod tests {
                         .collect(),
                 ))
             })
-            .recv()
+            .next()
             .await;
         }
 
@@ -7936,7 +7936,7 @@ mod tests {
         ) {
             fake.handle_request::<lsp::request::ResolveCompletionItem, _>(move |_| {
                 lsp::CompletionItem {
-                    additional_text_edits: edit.map(|(range, new_text)| {
+                    additional_text_edits: edit.clone().map(|(range, new_text)| {
                         vec![lsp::TextEdit::new(
                             lsp::Range::new(
                                 lsp::Position::new(range.start.row, range.start.column),
@@ -7948,7 +7948,7 @@ mod tests {
                     ..Default::default()
                 }
             })
-            .recv()
+            .next()
             .await;
         }
     }

crates/lsp/src/lsp.rs 🔗

@@ -1,5 +1,5 @@
 use anyhow::{anyhow, Context, Result};
-use futures::{io::BufWriter, AsyncRead, AsyncWrite};
+use futures::{channel::mpsc, io::BufWriter, AsyncRead, AsyncWrite};
 use gpui::{executor, Task};
 use parking_lot::{Mutex, RwLock};
 use postage::{barrier, oneshot, prelude::Stream, sink::Sink, watch};
@@ -481,16 +481,10 @@ impl Drop for Subscription {
 
 #[cfg(any(test, feature = "test-support"))]
 pub struct FakeLanguageServer {
-    handlers: Arc<
-        Mutex<
-            HashMap<
-                &'static str,
-                Box<dyn Send + FnOnce(usize, &[u8]) -> (Vec<u8>, barrier::Sender)>,
-            >,
-        >,
-    >,
-    outgoing_tx: channel::Sender<Vec<u8>>,
-    incoming_rx: channel::Receiver<Vec<u8>>,
+    handlers:
+        Arc<Mutex<HashMap<&'static str, Box<dyn Send + Sync + FnMut(usize, &[u8]) -> Vec<u8>>>>>,
+    outgoing_tx: mpsc::UnboundedSender<Vec<u8>>,
+    incoming_rx: mpsc::UnboundedReceiver<Vec<u8>>,
 }
 
 #[cfg(any(test, feature = "test-support"))]
@@ -508,8 +502,9 @@ impl LanguageServer {
 
         let mut fake = FakeLanguageServer::new(executor.clone(), stdin_reader, stdout_writer);
         fake.handle_request::<request::Initialize, _>({
+            let capabilities = capabilities.clone();
             move |_| InitializeResult {
-                capabilities,
+                capabilities: capabilities.clone(),
                 ..Default::default()
             }
         });
@@ -530,8 +525,8 @@ impl FakeLanguageServer {
     ) -> Self {
         use futures::StreamExt as _;
 
-        let (incoming_tx, incoming_rx) = channel::unbounded();
-        let (outgoing_tx, mut outgoing_rx) = channel::unbounded();
+        let (incoming_tx, incoming_rx) = mpsc::unbounded();
+        let (outgoing_tx, mut outgoing_rx) = mpsc::unbounded();
         let this = Self {
             outgoing_tx: outgoing_tx.clone(),
             incoming_rx,
@@ -545,36 +540,31 @@ impl FakeLanguageServer {
                 let mut buffer = Vec::new();
                 let mut stdin = smol::io::BufReader::new(stdin);
                 while Self::receive(&mut stdin, &mut buffer).await.is_ok() {
-                    if let Ok(request) = serde_json::from_slice::<AnyRequest>(&mut buffer) {
+                    if let Ok(request) = serde_json::from_slice::<AnyRequest>(&buffer) {
                         assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
 
-                        let handler = handlers.lock().remove(request.method);
-                        if let Some(handler) = handler {
-                            let (response, sent) =
-                                handler(request.id, request.params.get().as_bytes());
+                        if let Some(handler) = handlers.lock().get_mut(request.method) {
+                            let response = handler(request.id, request.params.get().as_bytes());
                             log::debug!("handled lsp request. method:{}", request.method);
-                            outgoing_tx.send(response).await.unwrap();
-                            drop(sent);
+                            outgoing_tx.unbounded_send(response)?;
                         } else {
                             log::debug!("unhandled lsp request. method:{}", request.method);
-                            outgoing_tx
-                                .send(
-                                    serde_json::to_vec(&AnyResponse {
-                                        id: request.id,
-                                        error: Some(Error {
-                                            message: "no handler".to_string(),
-                                        }),
-                                        result: None,
-                                    })
-                                    .unwrap(),
-                                )
-                                .await
-                                .unwrap();
+                            outgoing_tx.unbounded_send(
+                                serde_json::to_vec(&AnyResponse {
+                                    id: request.id,
+                                    error: Some(Error {
+                                        message: "no handler".to_string(),
+                                    }),
+                                    result: None,
+                                })
+                                .unwrap(),
+                            )?;
                         }
                     } else {
-                        incoming_tx.send(buffer.clone()).await.unwrap();
+                        incoming_tx.unbounded_send(buffer.clone())?;
                     }
                 }
+                Ok::<_, anyhow::Error>(())
             })
             .detach();
 
@@ -598,7 +588,7 @@ impl FakeLanguageServer {
             params,
         })
         .unwrap();
-        self.outgoing_tx.send(message).await.unwrap();
+        self.outgoing_tx.unbounded_send(message).unwrap();
     }
 
     pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
@@ -618,15 +608,15 @@ impl FakeLanguageServer {
         }
     }
 
-    pub fn handle_request<T, F>(&mut self, handler: F) -> barrier::Receiver
+    pub fn handle_request<T, F>(&mut self, mut handler: F) -> mpsc::UnboundedReceiver<()>
     where
         T: 'static + request::Request,
-        F: 'static + Send + FnOnce(T::Params) -> T::Result,
+        F: 'static + Send + Sync + FnMut(T::Params) -> T::Result,
     {
-        let (responded_tx, responded_rx) = barrier::channel();
-        let prev_handler = self.handlers.lock().insert(
+        let (responded_tx, responded_rx) = mpsc::unbounded();
+        self.handlers.lock().insert(
             T::METHOD,
-            Box::new(|id, params| {
+            Box::new(move |id, params| {
                 let result = handler(serde_json::from_slice::<T::Params>(params).unwrap());
                 let result = serde_json::to_string(&result).unwrap();
                 let result = serde_json::from_str::<&RawValue>(&result).unwrap();
@@ -635,18 +625,20 @@ impl FakeLanguageServer {
                     error: None,
                     result: Some(result),
                 };
-                (serde_json::to_vec(&response).unwrap(), responded_tx)
+                responded_tx.unbounded_send(()).ok();
+                serde_json::to_vec(&response).unwrap()
             }),
         );
-        if prev_handler.is_some() {
-            panic!(
-                "registered a new handler for LSP method '{}' before the previous handler was called",
-                T::METHOD
-            );
-        }
         responded_rx
     }
 
+    pub fn remove_request_handler<T>(&mut self)
+    where
+        T: 'static + request::Request,
+    {
+        self.handlers.lock().remove(T::METHOD);
+    }
+
     pub async fn start_progress(&mut self, token: impl Into<String>) {
         self.notify::<notification::Progress>(ProgressParams {
             token: NumberOrString::String(token.into()),

crates/server/src/rpc.rs 🔗

@@ -1093,6 +1093,7 @@ mod tests {
     };
     use ::rpc::Peer;
     use collections::BTreeMap;
+    use futures::channel::mpsc::UnboundedReceiver;
     use gpui::{executor, ModelHandle, TestAppContext};
     use parking_lot::Mutex;
     use postage::{mpsc, watch};
@@ -1126,7 +1127,7 @@ mod tests {
             tree_sitter_rust, AnchorRangeExt, Diagnostic, DiagnosticEntry, Language,
             LanguageConfig, LanguageRegistry, LanguageServerConfig, Point,
         },
-        lsp,
+        lsp::{self, FakeLanguageServer},
         project::{DiagnosticSummary, Project, ProjectPath},
         workspace::{Workspace, WorkspaceParams},
     };
@@ -2320,45 +2321,51 @@ mod tests {
 
         // Receive a completion request as the host's language server.
         // Return some completions from the host's language server.
-        fake_language_server.handle_request::<lsp::request::Completion, _>(|params| {
-            assert_eq!(
-                params.text_document_position.text_document.uri,
-                lsp::Url::from_file_path("/a/main.rs").unwrap(),
-            );
-            assert_eq!(
-                params.text_document_position.position,
-                lsp::Position::new(0, 14),
-            );
+        cx_a.foreground().start_waiting();
+        fake_language_server
+            .handle_request::<lsp::request::Completion, _>(|params| {
+                assert_eq!(
+                    params.text_document_position.text_document.uri,
+                    lsp::Url::from_file_path("/a/main.rs").unwrap(),
+                );
+                assert_eq!(
+                    params.text_document_position.position,
+                    lsp::Position::new(0, 14),
+                );
 
-            Some(lsp::CompletionResponse::Array(vec![
-                lsp::CompletionItem {
-                    label: "first_method(…)".into(),
-                    detail: Some("fn(&mut self, B) -> C".into()),
-                    text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
-                        new_text: "first_method($1)".to_string(),
-                        range: lsp::Range::new(
-                            lsp::Position::new(0, 14),
-                            lsp::Position::new(0, 14),
-                        ),
-                    })),
-                    insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
-                    ..Default::default()
-                },
-                lsp::CompletionItem {
-                    label: "second_method(…)".into(),
-                    detail: Some("fn(&mut self, C) -> D<E>".into()),
-                    text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
-                        new_text: "second_method()".to_string(),
-                        range: lsp::Range::new(
-                            lsp::Position::new(0, 14),
-                            lsp::Position::new(0, 14),
-                        ),
-                    })),
-                    insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
-                    ..Default::default()
-                },
-            ]))
-        });
+                Some(lsp::CompletionResponse::Array(vec![
+                    lsp::CompletionItem {
+                        label: "first_method(…)".into(),
+                        detail: Some("fn(&mut self, B) -> C".into()),
+                        text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
+                            new_text: "first_method($1)".to_string(),
+                            range: lsp::Range::new(
+                                lsp::Position::new(0, 14),
+                                lsp::Position::new(0, 14),
+                            ),
+                        })),
+                        insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
+                        ..Default::default()
+                    },
+                    lsp::CompletionItem {
+                        label: "second_method(…)".into(),
+                        detail: Some("fn(&mut self, C) -> D<E>".into()),
+                        text_edit: Some(lsp::CompletionTextEdit::Edit(lsp::TextEdit {
+                            new_text: "second_method()".to_string(),
+                            range: lsp::Range::new(
+                                lsp::Position::new(0, 14),
+                                lsp::Position::new(0, 14),
+                            ),
+                        })),
+                        insert_text_format: Some(lsp::InsertTextFormat::SNIPPET),
+                        ..Default::default()
+                    },
+                ]))
+            })
+            .next()
+            .await
+            .unwrap();
+        cx_a.foreground().finish_waiting();
 
         // Open the buffer on the host.
         let buffer_a = project_a
@@ -2896,58 +2903,62 @@ mod tests {
             editor.select_ranges([Point::new(1, 31)..Point::new(1, 31)], None, cx);
             cx.focus(&editor_b);
         });
-        fake_language_server.handle_request::<lsp::request::CodeActionRequest, _>(|params| {
-            assert_eq!(
-                params.text_document.uri,
-                lsp::Url::from_file_path("/a/main.rs").unwrap(),
-            );
-            assert_eq!(params.range.start, lsp::Position::new(1, 31));
-            assert_eq!(params.range.end, lsp::Position::new(1, 31));
-
-            Some(vec![lsp::CodeActionOrCommand::CodeAction(
-                lsp::CodeAction {
-                    title: "Inline into all callers".to_string(),
-                    edit: Some(lsp::WorkspaceEdit {
-                        changes: Some(
-                            [
-                                (
-                                    lsp::Url::from_file_path("/a/main.rs").unwrap(),
-                                    vec![lsp::TextEdit::new(
-                                        lsp::Range::new(
-                                            lsp::Position::new(1, 22),
-                                            lsp::Position::new(1, 34),
-                                        ),
-                                        "4".to_string(),
-                                    )],
-                                ),
-                                (
-                                    lsp::Url::from_file_path("/a/other.rs").unwrap(),
-                                    vec![lsp::TextEdit::new(
-                                        lsp::Range::new(
-                                            lsp::Position::new(0, 0),
-                                            lsp::Position::new(0, 27),
-                                        ),
-                                        "".to_string(),
-                                    )],
-                                ),
-                            ]
-                            .into_iter()
-                            .collect(),
-                        ),
-                        ..Default::default()
-                    }),
-                    data: Some(json!({
-                        "codeActionParams": {
-                            "range": {
-                                "start": {"line": 1, "column": 31},
-                                "end": {"line": 1, "column": 31},
+
+        fake_language_server
+            .handle_request::<lsp::request::CodeActionRequest, _>(|params| {
+                assert_eq!(
+                    params.text_document.uri,
+                    lsp::Url::from_file_path("/a/main.rs").unwrap(),
+                );
+                assert_eq!(params.range.start, lsp::Position::new(1, 31));
+                assert_eq!(params.range.end, lsp::Position::new(1, 31));
+
+                Some(vec![lsp::CodeActionOrCommand::CodeAction(
+                    lsp::CodeAction {
+                        title: "Inline into all callers".to_string(),
+                        edit: Some(lsp::WorkspaceEdit {
+                            changes: Some(
+                                [
+                                    (
+                                        lsp::Url::from_file_path("/a/main.rs").unwrap(),
+                                        vec![lsp::TextEdit::new(
+                                            lsp::Range::new(
+                                                lsp::Position::new(1, 22),
+                                                lsp::Position::new(1, 34),
+                                            ),
+                                            "4".to_string(),
+                                        )],
+                                    ),
+                                    (
+                                        lsp::Url::from_file_path("/a/other.rs").unwrap(),
+                                        vec![lsp::TextEdit::new(
+                                            lsp::Range::new(
+                                                lsp::Position::new(0, 0),
+                                                lsp::Position::new(0, 27),
+                                            ),
+                                            "".to_string(),
+                                        )],
+                                    ),
+                                ]
+                                .into_iter()
+                                .collect(),
+                            ),
+                            ..Default::default()
+                        }),
+                        data: Some(json!({
+                            "codeActionParams": {
+                                "range": {
+                                    "start": {"line": 1, "column": 31},
+                                    "end": {"line": 1, "column": 31},
+                                }
                             }
-                        }
-                    })),
-                    ..Default::default()
-                },
-            )])
-        });
+                        })),
+                        ..Default::default()
+                    },
+                )])
+            })
+            .next()
+            .await;
 
         // Toggle code actions and wait for them to display.
         editor_b.update(&mut cx_b, |editor, cx| {
@@ -2957,6 +2968,8 @@ mod tests {
             .condition(&cx_b, |editor, _| editor.context_menu_visible())
             .await;
 
+        fake_language_server.remove_request_handler::<lsp::request::CodeActionRequest>();
+
         // Confirming the code action will trigger a resolve request.
         let confirm_action = workspace_b
             .update(&mut cx_b, |workspace, cx| {
@@ -3594,7 +3607,24 @@ mod tests {
             .unwrap_or(10);
 
         let rng = Rc::new(RefCell::new(rng));
-        let lang_registry = Arc::new(LanguageRegistry::new());
+
+        let mut host_lang_registry = Arc::new(LanguageRegistry::new());
+        let guest_lang_registry = Arc::new(LanguageRegistry::new());
+
+        // Set up a fake language server.
+        let (language_server_config, fake_language_servers) = LanguageServerConfig::fake();
+        Arc::get_mut(&mut host_lang_registry)
+            .unwrap()
+            .add(Arc::new(Language::new(
+                LanguageConfig {
+                    name: "Rust".to_string(),
+                    path_suffixes: vec!["rs".to_string()],
+                    language_server: Some(language_server_config),
+                    ..Default::default()
+                },
+                Some(tree_sitter_rust::language()),
+            )));
+
         let fs = Arc::new(FakeFs::new(cx.background()));
         fs.insert_tree(
             "/_collab",
@@ -3622,7 +3652,7 @@ mod tests {
             Project::local(
                 host.client.clone(),
                 host.user_store.clone(),
-                lang_registry.clone(),
+                host_lang_registry.clone(),
                 fs.clone(),
                 cx,
             )
@@ -3647,6 +3677,7 @@ mod tests {
 
         clients.push(cx.foreground().spawn(host.simulate_host(
             host_project.clone(),
+            fake_language_servers,
             operations.clone(),
             max_operations,
             rng.clone(),
@@ -3676,7 +3707,7 @@ mod tests {
                     host_project_id,
                     guest.client.clone(),
                     guest.user_store.clone(),
-                    lang_registry.clone(),
+                    guest_lang_registry.clone(),
                     fs.clone(),
                     &mut guest_cx.to_async(),
                 )
@@ -3971,6 +4002,7 @@ mod tests {
         async fn simulate_host(
             mut self,
             project: ModelHandle<Project>,
+            fake_language_servers: UnboundedReceiver<FakeLanguageServer>,
             operations: Rc<Cell<usize>>,
             max_operations: usize,
             rng: Rc<RefCell<StdRng>>,
@@ -4055,6 +4087,7 @@ mod tests {
                             let letter = rng.borrow_mut().gen_range(b'a'..=b'z');
                             path.push(std::str::from_utf8(&[letter]).unwrap());
                         }
+                        path.set_extension("rs");
                         let parent_path = path.parent().unwrap();
 
                         log::info!("Host: creating file {:?}", path);