Get a basic end-to-end test for rust-analyzer integration working

Antonio Scandurra , Nathan Sobo , and Max Brunsfeld created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>
Co-Authored-By: Max Brunsfeld <max@zed.dev>

Change summary

Cargo.lock                      |   2 
crates/lsp/Cargo.toml           |   3 
crates/lsp/build.rs             |   6 
crates/lsp/src/lib.rs           | 231 ++++++++++++++++++++++++++--------
crates/project/src/lib.rs       |  14 -
crates/project/src/worktree.rs  |   4 
crates/project_panel/src/lib.rs |   3 
crates/workspace/src/lib.rs     |   3 
script/bundle                   |   2 
9 files changed, 194 insertions(+), 74 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -2956,12 +2956,14 @@ dependencies = [
  "anyhow",
  "futures",
  "gpui",
+ "log",
  "lsp-types",
  "parking_lot",
  "postage",
  "serde 1.0.125",
  "serde_json 1.0.64",
  "smol",
+ "unindent",
  "util",
 ]
 

crates/lsp/Cargo.toml 🔗

@@ -8,6 +8,7 @@ gpui = { path = "../gpui" }
 util = { path = "../util" }
 anyhow = "1.0"
 futures = "0.3"
+log = "0.4"
 lsp-types = "0.91"
 parking_lot = "0.11"
 postage = { version = "0.4.1", features = ["futures-traits"] }
@@ -17,3 +18,5 @@ smol = "1.2"
 
 [dev-dependencies]
 gpui = { path = "../gpui", features = ["test-support"] }
+util = { path = "../util", features = ["test-support"] }
+unindent = "0.1.7"

crates/lsp/build.rs 🔗

@@ -2,9 +2,9 @@ use std::env;
 
 fn main() {
     let target = env::var("TARGET").unwrap();
-    println!("cargo:rustc-env=TARGET={}", target);
+    println!("cargo:rustc-env=ZED_TARGET={}", target);
 
-    if let Ok(bundled) = env::var("BUNDLE") {
-        println!("cargo:rustc-env=BUNDLE={}", bundled);
+    if let Ok(bundled) = env::var("ZED_BUNDLE") {
+        println!("cargo:rustc-env=ZED_BUNDLE={}", bundled);
     }
 }

crates/lsp/src/lib.rs 🔗

@@ -1,9 +1,9 @@
 use anyhow::{anyhow, Context, Result};
 use gpui::{executor, AppContext, Task};
-use parking_lot::Mutex;
-use postage::{barrier, prelude::Stream};
+use parking_lot::{Mutex, RwLock};
+use postage::{barrier, oneshot, prelude::Stream, sink::Sink};
 use serde::{Deserialize, Serialize};
-use serde_json::value::RawValue;
+use serde_json::{json, value::RawValue};
 use smol::{
     channel,
     io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader},
@@ -24,16 +24,23 @@ use util::TryFutureExt;
 const JSON_RPC_VERSION: &'static str = "2.0";
 const CONTENT_LEN_HEADER: &'static str = "Content-Length: ";
 
+type NotificationHandler = Box<dyn Send + Sync + Fn(&str)>;
+type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
+
 pub struct LanguageServer {
     next_id: AtomicUsize,
     outbound_tx: channel::Sender<Vec<u8>>,
+    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
     response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
     _input_task: Task<Option<()>>,
     _output_task: Task<Option<()>>,
     initialized: barrier::Receiver,
 }
 
-type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
+pub struct Subscription {
+    method: &'static str,
+    notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
+}
 
 #[derive(Serialize)]
 struct Request<T> {
@@ -48,8 +55,8 @@ struct Response<'a> {
     id: usize,
     #[serde(default)]
     error: Option<Error>,
-    #[serde(default, borrow)]
-    result: Option<&'a RawValue>,
+    #[serde(borrow)]
+    result: &'a RawValue,
 }
 
 #[derive(Serialize)]
@@ -67,29 +74,33 @@ struct InboundNotification<'a> {
     params: &'a RawValue,
 }
 
-#[derive(Deserialize)]
+#[derive(Debug, Deserialize)]
 struct Error {
     message: String,
 }
 
 impl LanguageServer {
-    pub fn rust(cx: &AppContext) -> Result<Arc<Self>> {
-        const BUNDLE: Option<&'static str> = option_env!("BUNDLE");
-        const TARGET: &'static str = env!("TARGET");
+    pub fn rust(root_path: &Path, cx: &AppContext) -> Result<Arc<Self>> {
+        const ZED_BUNDLE: Option<&'static str> = option_env!("ZED_BUNDLE");
+        const ZED_TARGET: &'static str = env!("ZED_TARGET");
 
-        let rust_analyzer_name = format!("rust-analyzer-{}", TARGET);
-        if BUNDLE.map_or(Ok(false), |b| b.parse())? {
+        let rust_analyzer_name = format!("rust-analyzer-{}", ZED_TARGET);
+        if ZED_BUNDLE.map_or(Ok(false), |b| b.parse())? {
             let rust_analyzer_path = cx
                 .platform()
                 .path_for_resource(Some(&rust_analyzer_name), None)?;
-            Self::new(&rust_analyzer_path, cx.background())
+            Self::new(root_path, &rust_analyzer_path, cx.background())
         } else {
-            Self::new(Path::new(&rust_analyzer_name), cx.background())
+            Self::new(root_path, Path::new(&rust_analyzer_name), cx.background())
         }
     }
 
-    pub fn new(path: &Path, background: &executor::Background) -> Result<Arc<Self>> {
-        let mut server = Command::new(path)
+    pub fn new(
+        root_path: &Path,
+        server_path: &Path,
+        background: &executor::Background,
+    ) -> Result<Arc<Self>> {
+        let mut server = Command::new(server_path)
             .stdin(Stdio::piped())
             .stdout(Stdio::piped())
             .stderr(Stdio::inherit())
@@ -97,9 +108,11 @@ impl LanguageServer {
         let mut stdin = server.stdin.take().unwrap();
         let mut stdout = BufReader::new(server.stdout.take().unwrap());
         let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
-        let response_handlers = Arc::new(Mutex::new(HashMap::<usize, ResponseHandler>::new()));
+        let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
+        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
         let _input_task = background.spawn(
             {
+                let notification_handlers = notification_handlers.clone();
                 let response_handlers = response_handlers.clone();
                 async move {
                     let mut buffer = Vec::new();
@@ -116,15 +129,21 @@ impl LanguageServer {
 
                         buffer.resize(message_len, 0);
                         stdout.read_exact(&mut buffer).await?;
-                        if let Ok(InboundNotification { .. }) = serde_json::from_slice(&buffer) {
+
+                        if let Ok(InboundNotification { method, params }) =
+                            serde_json::from_slice(&buffer)
+                        {
+                            if let Some(handler) = notification_handlers.read().get(method) {
+                                handler(params.get());
+                            }
                         } else if let Ok(Response { id, error, result }) =
                             serde_json::from_slice(&buffer)
                         {
                             if let Some(handler) = response_handlers.lock().remove(&id) {
-                                if let Some(result) = result {
-                                    handler(Ok(result.get()));
-                                } else if let Some(error) = error {
+                                if let Some(error) = error {
                                     handler(Err(error));
+                                } else {
+                                    handler(Ok(result.get()));
                                 }
                             }
                         } else {
@@ -142,6 +161,8 @@ impl LanguageServer {
             async move {
                 let mut content_len_buffer = Vec::new();
                 loop {
+                    content_len_buffer.clear();
+
                     let message = outbound_rx.recv().await?;
                     write!(content_len_buffer, "{}", message.len()).unwrap();
                     stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
@@ -155,6 +176,7 @@ impl LanguageServer {
 
         let (initialized_tx, initialized_rx) = barrier::channel();
         let this = Arc::new(Self {
+            notification_handlers,
             response_handlers,
             next_id: Default::default(),
             outbound_tx,
@@ -163,11 +185,13 @@ impl LanguageServer {
             initialized: initialized_rx,
         });
 
+        let root_uri =
+            lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
         background
             .spawn({
                 let this = this.clone();
                 async move {
-                    this.init().log_err().await;
+                    this.init(root_uri).log_err().await;
                     drop(initialized_tx);
                 }
             })
@@ -176,45 +200,74 @@ impl LanguageServer {
         Ok(this)
     }
 
-    async fn init(self: Arc<Self>) -> Result<()> {
-        let res = self
-            .request_internal::<lsp_types::request::Initialize>(
-                lsp_types::InitializeParams {
-                    process_id: Default::default(),
-                    root_path: Default::default(),
-                    root_uri: Default::default(),
-                    initialization_options: Default::default(),
-                    capabilities: Default::default(),
-                    trace: Default::default(),
-                    workspace_folders: Default::default(),
-                    client_info: Default::default(),
-                    locale: Default::default(),
-                },
-                false,
-            )
-            .await?;
+    async fn init(self: Arc<Self>, root_uri: lsp_types::Url) -> Result<()> {
+        self.request_internal::<lsp_types::request::Initialize>(lsp_types::InitializeParams {
+            process_id: Default::default(),
+            root_path: Default::default(),
+            root_uri: Some(root_uri),
+            initialization_options: Default::default(),
+            capabilities: lsp_types::ClientCapabilities {
+                experimental: Some(json!({
+                    "serverStatusNotification": true,
+                })),
+                ..Default::default()
+            },
+            trace: Default::default(),
+            workspace_folders: Default::default(),
+            client_info: Default::default(),
+            locale: Default::default(),
+        })
+        .await?;
         self.notify_internal::<lsp_types::notification::Initialized>(
             lsp_types::InitializedParams {},
-            false,
         )
         .await?;
         Ok(())
     }
 
+    pub fn on_notification<T, F>(&self, f: F) -> Subscription
+    where
+        T: lsp_types::notification::Notification,
+        F: 'static + Send + Sync + Fn(T::Params),
+    {
+        let prev_handler = self.notification_handlers.write().insert(
+            T::METHOD,
+            Box::new(
+                move |notification| match serde_json::from_str(notification) {
+                    Ok(notification) => f(notification),
+                    Err(err) => log::error!("error parsing notification {}: {}", T::METHOD, err),
+                },
+            ),
+        );
+
+        assert!(
+            prev_handler.is_none(),
+            "registered multiple handlers for the same notification"
+        );
+
+        Subscription {
+            method: T::METHOD,
+            notification_handlers: self.notification_handlers.clone(),
+        }
+    }
+
     pub fn request<T: lsp_types::request::Request>(
-        self: &Arc<Self>,
+        self: Arc<Self>,
         params: T::Params,
     ) -> impl Future<Output = Result<T::Result>>
     where
         T::Result: 'static + Send,
     {
-        self.request_internal::<T>(params, true)
+        let this = self.clone();
+        async move {
+            this.initialized.clone().recv().await;
+            this.request_internal::<T>(params).await
+        }
     }
 
     fn request_internal<T: lsp_types::request::Request>(
         self: &Arc<Self>,
         params: T::Params,
-        wait_for_initialization: bool,
     ) -> impl Future<Output = Result<T::Result>>
     where
         T::Result: 'static + Send,
@@ -228,7 +281,7 @@ impl LanguageServer {
         })
         .unwrap();
         let mut response_handlers = self.response_handlers.lock();
-        let (tx, rx) = smol::channel::bounded(1);
+        let (mut tx, mut rx) = oneshot::channel();
         response_handlers.insert(
             id,
             Box::new(move |result| {
@@ -238,17 +291,14 @@ impl LanguageServer {
                     }
                     Err(error) => Err(anyhow!("{}", error.message)),
                 };
-                let _ = smol::block_on(tx.send(response));
+                let _ = tx.try_send(response);
             }),
         );
 
         let this = self.clone();
         async move {
-            if wait_for_initialization {
-                this.initialized.clone().recv().await;
-            }
             this.outbound_tx.send(message).await?;
-            rx.recv().await?
+            rx.recv().await.unwrap()
         }
     }
 
@@ -256,13 +306,16 @@ impl LanguageServer {
         self: &Arc<Self>,
         params: T::Params,
     ) -> impl Future<Output = Result<()>> {
-        self.notify_internal::<T>(params, true)
+        let this = self.clone();
+        async move {
+            this.initialized.clone().recv().await;
+            this.notify_internal::<T>(params).await
+        }
     }
 
     fn notify_internal<T: lsp_types::notification::Notification>(
         self: &Arc<Self>,
         params: T::Params,
-        wait_for_initialization: bool,
     ) -> impl Future<Output = Result<()>> {
         let message = serde_json::to_vec(&OutboundNotification {
             jsonrpc: JSON_RPC_VERSION,
@@ -273,22 +326,90 @@ impl LanguageServer {
 
         let this = self.clone();
         async move {
-            if wait_for_initialization {
-                this.initialized.clone().recv().await;
-            }
             this.outbound_tx.send(message).await?;
             Ok(())
         }
     }
 }
 
+impl Drop for Subscription {
+    fn drop(&mut self) {
+        self.notification_handlers.write().remove(self.method);
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;
     use gpui::TestAppContext;
+    use unindent::Unindent;
+    use util::test::temp_tree;
 
     #[gpui::test]
     async fn test_basic(cx: TestAppContext) {
-        let server = cx.read(|cx| LanguageServer::rust(cx).unwrap());
+        let root_dir = temp_tree(json!({
+            "Cargo.toml": r#"
+                [package]
+                name = "temp"
+                version = "0.1.0"
+                edition = "2018"
+            "#.unindent(),
+            "src": {
+                "lib.rs": r#"
+                    fn fun() {
+                        let hello = "world";
+                    }
+                "#.unindent()
+            }
+        }));
+
+        let server = cx.read(|cx| LanguageServer::rust(root_dir.path(), cx).unwrap());
+        server.next_idle_notification().await;
+
+        let hover = server
+            .request::<lsp_types::request::HoverRequest>(lsp_types::HoverParams {
+                text_document_position_params: lsp_types::TextDocumentPositionParams {
+                    text_document: lsp_types::TextDocumentIdentifier::new(
+                        lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap(),
+                    ),
+                    position: lsp_types::Position::new(1, 21),
+                },
+                work_done_progress_params: Default::default(),
+            })
+            .await
+            .unwrap()
+            .unwrap();
+        assert_eq!(
+            hover.contents,
+            lsp_types::HoverContents::Markup(lsp_types::MarkupContent {
+                kind: lsp_types::MarkupKind::Markdown,
+                value: "&str".to_string()
+            })
+        );
+    }
+
+    impl LanguageServer {
+        async fn next_idle_notification(self: &Arc<Self>) {
+            let (tx, rx) = channel::unbounded();
+            let _subscription =
+                self.on_notification::<ServerStatusNotification, _>(move |params| {
+                    if params.quiescent {
+                        tx.try_send(()).unwrap();
+                    }
+                });
+            let _ = rx.recv().await;
+        }
+    }
+
+    pub enum ServerStatusNotification {}
+
+    impl lsp_types::notification::Notification for ServerStatusNotification {
+        type Params = ServerStatusParams;
+        const METHOD: &'static str = "experimental/serverStatus";
+    }
+
+    #[derive(Deserialize, Serialize, PartialEq, Eq, Clone)]
+    pub struct ServerStatusParams {
+        pub quiescent: bool,
     }
 }

crates/project/src/lib.rs 🔗

@@ -7,8 +7,7 @@ use buffer::LanguageRegistry;
 use client::Client;
 use futures::Future;
 use fuzzy::{PathMatch, PathMatchCandidate, PathMatchCandidateSet};
-use gpui::{executor, AppContext, Entity, ModelContext, ModelHandle, Task};
-use lsp::LanguageServer;
+use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
 use std::{
     path::Path,
     sync::{atomic::AtomicBool, Arc},
@@ -24,7 +23,6 @@ pub struct Project {
     languages: Arc<LanguageRegistry>,
     client: Arc<client::Client>,
     fs: Arc<dyn Fs>,
-    language_server: Arc<LanguageServer>,
 }
 
 pub enum Event {
@@ -45,19 +43,13 @@ pub struct ProjectEntry {
 }
 
 impl Project {
-    pub fn new(
-        languages: Arc<LanguageRegistry>,
-        rpc: Arc<Client>,
-        fs: Arc<dyn Fs>,
-        cx: &AppContext,
-    ) -> Self {
+    pub fn new(languages: Arc<LanguageRegistry>, rpc: Arc<Client>, fs: Arc<dyn Fs>) -> Self {
         Self {
             worktrees: Default::default(),
             active_entry: None,
             languages,
             client: rpc,
             fs,
-            language_server: LanguageServer::rust(cx).unwrap(),
         }
     }
 
@@ -416,6 +408,6 @@ mod tests {
         let languages = Arc::new(LanguageRegistry::new());
         let fs = Arc::new(RealFs);
         let rpc = client::Client::new();
-        cx.add_model(|cx| Project::new(languages, rpc, fs, cx))
+        cx.add_model(|_| Project::new(languages, rpc, fs))
     }
 }

crates/project/src/worktree.rs 🔗

@@ -14,6 +14,7 @@ use gpui::{
     Task, UpgradeModelHandle, WeakModelHandle,
 };
 use lazy_static::lazy_static;
+use lsp::LanguageServer;
 use parking_lot::Mutex;
 use postage::{
     prelude::{Sink as _, Stream as _},
@@ -684,6 +685,7 @@ pub struct LocalWorktree {
     queued_operations: Vec<(u64, Operation)>,
     rpc: Arc<Client>,
     fs: Arc<dyn Fs>,
+    language_server: Arc<LanguageServer>,
 }
 
 #[derive(Default, Deserialize)]
@@ -721,6 +723,7 @@ impl LocalWorktree {
         let (scan_states_tx, scan_states_rx) = smol::channel::unbounded();
         let (mut last_scan_state_tx, last_scan_state_rx) = watch::channel_with(ScanState::Scanning);
         let tree = cx.add_model(move |cx: &mut ModelContext<Worktree>| {
+            let language_server = LanguageServer::rust(&abs_path, cx).unwrap();
             let mut snapshot = Snapshot {
                 id: cx.model_id(),
                 scan_id: 0,
@@ -796,6 +799,7 @@ impl LocalWorktree {
                 languages,
                 rpc,
                 fs,
+                language_server,
             };
 
             cx.spawn_weak(|this, mut cx| async move {

crates/project_panel/src/lib.rs 🔗

@@ -617,12 +617,11 @@ mod tests {
         )
         .await;
 
-        let project = cx.add_model(|cx| {
+        let project = cx.add_model(|_| {
             Project::new(
                 params.languages.clone(),
                 params.client.clone(),
                 params.fs.clone(),
-                cx,
             )
         });
         let root1 = project

crates/workspace/src/lib.rs 🔗

@@ -322,12 +322,11 @@ pub struct Workspace {
 
 impl Workspace {
     pub fn new(params: &WorkspaceParams, cx: &mut ViewContext<Self>) -> Self {
-        let project = cx.add_model(|cx| {
+        let project = cx.add_model(|_| {
             Project::new(
                 params.languages.clone(),
                 params.client.clone(),
                 params.fs.clone(),
-                cx,
             )
         });
         cx.observe(&project, |_, _, cx| cx.notify()).detach();

script/bundle 🔗

@@ -2,7 +2,7 @@
 
 set -e
 
-export BUNDLE=true
+export ZED_BUNDLE=true
 
 # Install cargo-bundle 0.5.0 if it's not already installed
 cargo install cargo-bundle --version 0.5.0