Add a fake lsp server

Max Brunsfeld and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

Cargo.lock                |  11 +
crates/lsp/Cargo.toml     |   6 
crates/lsp/src/lib.rs     | 259 +++++++++++++++++++++++++++++++++++++---
crates/project/Cargo.toml |   2 
4 files changed, 255 insertions(+), 23 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -328,6 +328,15 @@ dependencies = [
  "futures-lite",
 ]
 
+[[package]]
+name = "async-pipe"
+version = "0.1.3"
+source = "git+https://github.com/routerify/async-pipe-rs?rev=feeb77e83142a9ff837d0767652ae41bfc5d8e47#feeb77e83142a9ff837d0767652ae41bfc5d8e47"
+dependencies = [
+ "futures",
+ "log",
+]
+
 [[package]]
 name = "async-process"
 version = "1.0.2"
@@ -2954,6 +2963,7 @@ name = "lsp"
 version = "0.1.0"
 dependencies = [
  "anyhow",
+ "async-pipe",
  "futures",
  "gpui",
  "log",
@@ -2962,6 +2972,7 @@ dependencies = [
  "postage",
  "serde 1.0.125",
  "serde_json 1.0.64",
+ "simplelog",
  "smol",
  "unindent",
  "util",

crates/lsp/Cargo.toml 🔗

@@ -3,10 +3,14 @@ name = "lsp"
 version = "0.1.0"
 edition = "2018"
 
+[features]
+test-support = ["async-pipe"]
+
 [dependencies]
 gpui = { path = "../gpui" }
 util = { path = "../util" }
 anyhow = "1.0"
+async-pipe = { git = "https://github.com/routerify/async-pipe-rs", rev = "feeb77e83142a9ff837d0767652ae41bfc5d8e47", optional = true }
 futures = "0.3"
 log = "0.4"
 lsp-types = "0.91"
@@ -19,4 +23,6 @@ smol = "1.2"
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
 util = { path = "../util", features = ["test-support"] }
+async-pipe = { git = "https://github.com/routerify/async-pipe-rs", rev = "feeb77e83142a9ff837d0767652ae41bfc5d8e47" }
+simplelog = "0.9"
 unindent = "0.1.7"

crates/lsp/src/lib.rs 🔗

@@ -1,4 +1,5 @@
 use anyhow::{anyhow, Context, Result};
+use futures::{io::BufWriter, AsyncRead, AsyncWrite};
 use gpui::{executor, AppContext, Task};
 use parking_lot::{Mutex, RwLock};
 use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
@@ -13,6 +14,7 @@ use std::{
     collections::HashMap,
     future::Future,
     io::Write,
+    marker::PhantomData,
     str::FromStr,
     sync::{
         atomic::{AtomicUsize, Ordering::SeqCst},
@@ -22,6 +24,8 @@ use std::{
 use std::{path::Path, process::Stdio};
 use util::TryFutureExt;
 
+pub use lsp_types::*;
+
 const JSON_RPC_VERSION: &'static str = "2.0";
 const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 
@@ -43,16 +47,16 @@ pub struct Subscription {
     notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
 }
 
-#[derive(Serialize)]
-struct Request<T> {
-    jsonrpc: &'static str,
+#[derive(Serialize, Deserialize)]
+struct Request<'a, T> {
+    jsonrpc: &'a str,
     id: usize,
-    method: &'static str,
+    method: &'a str,
     params: T,
 }
 
-#[derive(Deserialize)]
-struct Response<'a> {
+#[derive(Serialize, Deserialize)]
+struct AnyResponse<'a> {
     id: usize,
     #[serde(default)]
     error: Option<Error>,
@@ -60,22 +64,24 @@ struct Response<'a> {
     result: &'a RawValue,
 }
 
-#[derive(Serialize)]
-struct OutboundNotification<T> {
-    jsonrpc: &'static str,
-    method: &'static str,
+#[derive(Serialize, Deserialize)]
+struct Notification<'a, T> {
+    #[serde(borrow)]
+    jsonrpc: &'a str,
+    #[serde(borrow)]
+    method: &'a str,
     params: T,
 }
 
 #[derive(Deserialize)]
-struct InboundNotification<'a> {
+struct AnyNotification<'a> {
     #[serde(borrow)]
     method: &'a str,
     #[serde(borrow)]
     params: &'a RawValue,
 }
 
-#[derive(Debug, Deserialize)]
+#[derive(Debug, Serialize, Deserialize)]
 struct Error {
     message: String,
 }
@@ -90,24 +96,46 @@ impl LanguageServer {
             let rust_analyzer_path = cx
                 .platform()
                 .path_for_resource(Some(&rust_analyzer_name), None)?;
-            Self::new(root_path, &rust_analyzer_path, cx.background())
+            Self::new(root_path, &rust_analyzer_path, &[], cx.background())
         } else {
-            Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
+            Self::new(
+                root_path,
+                Path::new(&rust_analyzer_name),
+                &[],
+                cx.background(),
+            )
         }
     }
 
     pub fn new(
         root_path: &Path,
         server_path: &Path,
+        server_args: &[&str],
         background: &executor::Background,
     ) -> Result<Arc<Self>> {
         let mut server = Command::new(server_path)
+            .args(server_args)
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())
             .stderr(Stdio::inherit())
             .spawn()?;
-        let mut stdin = server.stdin.take().unwrap();
-        let mut stdout = BufReader::new(server.stdout.take().unwrap());
+        let stdin = server.stdin.take().unwrap();
+        let stdout = server.stdout.take().unwrap();
+        Self::new_internal(root_path, stdin, stdout, background)
+    }
+
+    fn new_internal<Stdin, Stdout>(
+        root_path: &Path,
+        stdin: Stdin,
+        stdout: Stdout,
+        background: &executor::Background,
+    ) -> Result<Arc<Self>>
+    where
+        Stdin: AsyncWrite + Unpin + Send + 'static,
+        Stdout: AsyncRead + Unpin + Send + 'static,
+    {
+        let mut stdin = BufWriter::new(stdin);
+        let mut stdout = BufReader::new(stdout);
         let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
         let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
         let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
@@ -119,7 +147,6 @@ impl LanguageServer {
                     let mut buffer = Vec::new();
                     loop {
                         buffer.clear();
-
                         stdout.read_until(b'\n', &mut buffer).await?;
                         stdout.read_until(b'\n', &mut buffer).await?;
                         let message_len: usize = std::str::from_utf8(&buffer)?
@@ -131,7 +158,7 @@ impl LanguageServer {
                         buffer.resize(message_len, 0);
                         stdout.read_exact(&mut buffer).await?;
 
-                        if let Ok(InboundNotification { method, params }) =
+                        if let Ok(AnyNotification { method, params }) =
                             serde_json::from_slice(&buffer)
                         {
                             if let Some(handler) = notification_handlers.read().get(method) {
@@ -146,7 +173,7 @@ impl LanguageServer {
                                     .unwrap()
                                 );
                             }
-                        } else if let Ok(Response { id, error, result }) =
+                        } else if let Ok(AnyResponse { id, error, result }) =
                             serde_json::from_slice(&buffer)
                         {
                             if let Some(handler) = response_handlers.lock().remove(&id) {
@@ -179,6 +206,7 @@ impl LanguageServer {
                     stdin.write_all(&content_len_buffer).await?;
                     stdin.write_all("\r\n\r\n".as_bytes()).await?;
                     stdin.write_all(&message).await?;
+                    stdin.flush().await?;
                 }
             }
             .log_err(),
@@ -211,7 +239,8 @@ impl LanguageServer {
     }
 
     async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
-        self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
+        #[allow(deprecated)]
+        let params = lsp_types::InitializeParams {
             process_id: Default::default(),
             root_path: Default::default(),
             root_uri: Some(root_uri),
@@ -226,8 +255,10 @@ impl LanguageServer {
             workspace_folders: Default::default(),
             client_info: Default::default(),
             locale: Default::default(),
-        })
-        .await?;
+        };
+
+        self.request_internal::<lsp_types::request::Initialize>(params)
+            .await?;
         self.notify_internal::<lsp_types::notification::Initialized>(
             lsp_types::InitializedParams {},
         )
@@ -327,7 +358,7 @@ impl LanguageServer {
         self: &Arc<Self>,
         params: T::Params,
     ) -> impl Future<Output = Result<()>> {
-        let message = serde_json::to_vec(&OutboundNotification {
+        let message = serde_json::to_vec(&Notification {
             jsonrpc: JSON_RPC_VERSION,
             method: T::METHOD,
             params,
@@ -342,16 +373,136 @@ impl LanguageServer {
     }
 }
 
+impl Subscription {
+    pub fn detach(mut self) {
+        self.method = "";
+    }
+}
+
 impl Drop for Subscription {
     fn drop(&mut self) {
         self.notification_handlers.write().remove(self.method);
     }
 }
 
+#[cfg(any(test, feature = "test-support"))]
+pub struct FakeLanguageServer {
+    buffer: Vec<u8>,
+    stdin: smol::io::BufReader<async_pipe::PipeReader>,
+    stdout: smol::io::BufWriter<async_pipe::PipeWriter>,
+}
+
+#[cfg(any(test, feature = "test-support"))]
+pub struct RequestId<T> {
+    id: usize,
+    _type: std::marker::PhantomData<T>,
+}
+
+#[cfg(any(test, feature = "test-support"))]
+impl LanguageServer {
+    pub async fn fake(executor: &executor::Background) -> (Arc<Self>, FakeLanguageServer) {
+        let stdin = async_pipe::pipe();
+        let stdout = async_pipe::pipe();
+        (
+            Self::new_internal(Path::new("/"), stdin.0, stdout.1, executor).unwrap(),
+            FakeLanguageServer {
+                stdin: smol::io::BufReader::new(stdin.1),
+                stdout: smol::io::BufWriter::new(stdout.0),
+                buffer: Vec::new(),
+            },
+        )
+    }
+}
+
+#[cfg(any(test, feature = "test-support"))]
+impl FakeLanguageServer {
+    pub async fn notify<T: notification::Notification>(&mut self, params: T::Params) {
+        let message = serde_json::to_vec(&Notification {
+            jsonrpc: JSON_RPC_VERSION,
+            method: T::METHOD,
+            params,
+        })
+        .unwrap();
+        self.send(message).await;
+    }
+
+    pub async fn respond<'a, T: request::Request>(
+        &mut self,
+        request_id: RequestId<T>,
+        result: T::Result,
+    ) {
+        let result = serde_json::to_string(&result).unwrap();
+        let message = serde_json::to_vec(&AnyResponse {
+            id: request_id.id,
+            error: None,
+            result: &RawValue::from_string(result).unwrap(),
+        })
+        .unwrap();
+        self.send(message).await;
+    }
+
+    pub async fn receive_request<T: request::Request>(&mut self) -> (RequestId<T>, T::Params) {
+        self.receive().await;
+        let request = serde_json::from_slice::<Request<T::Params>>(&self.buffer).unwrap();
+        assert_eq!(request.method, T::METHOD);
+        assert_eq!(request.jsonrpc, JSON_RPC_VERSION);
+        (
+            RequestId {
+                id: request.id,
+                _type: PhantomData,
+            },
+            request.params,
+        )
+    }
+
+    pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
+        self.receive().await;
+        let notification = serde_json::from_slice::<Notification<T::Params>>(&self.buffer).unwrap();
+        assert_eq!(notification.method, T::METHOD);
+        notification.params
+    }
+
+    async fn send(&mut self, message: Vec<u8>) {
+        self.stdout
+            .write_all(CONTENT_LEN_HEADER.as_bytes())
+            .await
+            .unwrap();
+        self.stdout
+            .write_all((format!("{}", message.len())).as_bytes())
+            .await
+            .unwrap();
+        self.stdout.write_all("\r\n\r\n".as_bytes()).await.unwrap();
+        self.stdout.write_all(&message).await.unwrap();
+        self.stdout.flush().await.unwrap();
+    }
+
+    async fn receive(&mut self) {
+        self.buffer.clear();
+        self.stdin
+            .read_until(b'\n', &mut self.buffer)
+            .await
+            .unwrap();
+        self.stdin
+            .read_until(b'\n', &mut self.buffer)
+            .await
+            .unwrap();
+        let message_len: usize = std::str::from_utf8(&self.buffer)
+            .unwrap()
+            .strip_prefix(CONTENT_LEN_HEADER)
+            .unwrap()
+            .trim_end()
+            .parse()
+            .unwrap();
+        self.buffer.resize(message_len, 0);
+        self.stdin.read_exact(&mut self.buffer).await.unwrap();
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use gpui::TestAppContext;
+    use simplelog::SimpleLogger;
     use unindent::Unindent;
     use util::test::temp_tree;
 
@@ -414,6 +565,68 @@ mod tests {
         );
     }
 
+    #[gpui::test]
+    async fn test_fake(cx: TestAppContext) {
+        SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();
+
+        let (server, mut fake) = LanguageServer::fake(&cx.background()).await;
+
+        let (message_tx, message_rx) = channel::unbounded();
+        let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
+        server
+            .on_notification::<notification::ShowMessage, _>(move |params| {
+                message_tx.try_send(params).unwrap()
+            })
+            .detach();
+        server
+            .on_notification::<notification::PublishDiagnostics, _>(move |params| {
+                diagnostics_tx.try_send(params).unwrap()
+            })
+            .detach();
+
+        let (init_id, _) = fake.receive_request::<request::Initialize>().await;
+        fake.respond(init_id, InitializeResult::default()).await;
+        fake.receive_notification::<notification::Initialized>()
+            .await;
+
+        server
+            .notify::<notification::DidOpenTextDocument>(DidOpenTextDocumentParams {
+                text_document: TextDocumentItem::new(
+                    Url::from_str("file://a/b").unwrap(),
+                    "rust".to_string(),
+                    0,
+                    "".to_string(),
+                ),
+            })
+            .await
+            .unwrap();
+        assert_eq!(
+            fake.receive_notification::<notification::DidOpenTextDocument>()
+                .await
+                .text_document
+                .uri
+                .as_str(),
+            "file://a/b"
+        );
+
+        fake.notify::<notification::ShowMessage>(ShowMessageParams {
+            typ: MessageType::ERROR,
+            message: "ok".to_string(),
+        })
+        .await;
+        fake.notify::<notification::PublishDiagnostics>(PublishDiagnosticsParams {
+            uri: Url::from_str("file://b/c").unwrap(),
+            version: Some(5),
+            diagnostics: vec![],
+        })
+        .await;
+        assert_eq!(message_rx.recv().await.unwrap().message, "ok");
+        assert_eq!(
+            diagnostics_rx.recv().await.unwrap().uri.as_str(),
+            "file://b/c"
+        );
+    }
+
     impl LanguageServer {
         async fn next_idle_notification(self: &Arc<Self>) {
             let (tx, rx) = channel::unbounded();

crates/project/Cargo.toml 🔗

@@ -33,6 +33,8 @@ toml = "0.5"
 
 [dev-dependencies]
 client = { path = "../client", features = ["test-support"] }
+gpui = { path = "../gpui", features = ["test-support"] }
+lsp = { path = "../lsp", features = ["test-support"] }
 util = { path = "../util", features = ["test-support"] }
 rpc = { path = "../rpc", features = ["test-support"] }
 rand = "0.8.3"