Delay quit until language servers are gracefully shut down

Max Brunsfeld , Antonio Scandurra , and Nathan Sobo created

Co-Authored-By: Antonio Scandurra <me@as-cii.com>
Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/language/src/language.rs |   2 
crates/language/src/tests.rs    |   2 
crates/lsp/src/lib.rs           | 119 +++++++++++++++++++++++++----------
crates/project/src/worktree.rs  |   2 
crates/zed/src/main.rs          |  12 ++
5 files changed, 98 insertions(+), 39 deletions(-)

Detailed changes

crates/language/src/language.rs 🔗

@@ -124,7 +124,7 @@ impl Language {
             } else {
                 Path::new(&config.binary).to_path_buf()
             };
-            lsp::LanguageServer::new(&binary_path, root_path, cx.background()).map(Some)
+            lsp::LanguageServer::new(&binary_path, root_path, cx.background().clone()).map(Some)
         } else {
             Ok(None)
         }

crates/language/src/tests.rs 🔗

@@ -409,7 +409,7 @@ fn test_autoindent_adjusts_lines_when_only_text_changes(cx: &mut MutableAppConte
 
 #[gpui::test]
 async fn test_diagnostics(mut cx: gpui::TestAppContext) {
-    let (language_server, mut fake) = lsp::LanguageServer::fake(&cx.background()).await;
+    let (language_server, mut fake) = lsp::LanguageServer::fake(cx.background()).await;
 
     let text = "
         fn a() { A }

crates/lsp/src/lib.rs 🔗

@@ -36,9 +36,10 @@ pub struct LanguageServer {
     outbound_tx: channel::Sender<Vec<u8>>,
     notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
     response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
-    _input_task: Task<Option<()>>,
-    _output_task: Task<Option<()>>,
+    executor: Arc<executor::Background>,
+    io_tasks: Option<(Task<Option<()>>, Task<Option<()>>)>,
     initialized: barrier::Receiver,
+    output_done_rx: Option<barrier::Receiver>,
 }
 
 pub struct Subscription {
@@ -89,7 +90,7 @@ impl LanguageServer {
     pub fn new(
         binary_path: &Path,
         root_path: &Path,
-        background: &executor::Background,
+        background: Arc<executor::Background>,
     ) -> Result<Arc<Self>> {
         let mut server = Command::new(binary_path)
             .stdin(Stdio::piped())
@@ -105,7 +106,7 @@ impl LanguageServer {
         stdin: Stdin,
         stdout: Stdout,
         root_path: &Path,
-        background: &executor::Background,
+        executor: Arc<executor::Background>,
     ) -> Result<Arc<Self>>
     where
         Stdin: AsyncWrite + Unpin + Send + 'static,
@@ -116,7 +117,7 @@ impl LanguageServer {
         let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
         let notification_handlers = Arc::new(RwLock::new(HashMap::<_, NotificationHandler>::new()));
         let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::new()));
-        let _input_task = background.spawn(
+        let input_task = executor.spawn(
             {
                 let notification_handlers = notification_handlers.clone();
                 let response_handlers = response_handlers.clone();
@@ -171,13 +172,12 @@ impl LanguageServer {
             }
             .log_err(),
         );
-        let _output_task = background.spawn(
+        let (output_done_tx, output_done_rx) = barrier::channel();
+        let output_task = executor.spawn(
             async move {
                 let mut content_len_buffer = Vec::new();
-                loop {
+                while let Ok(message) = outbound_rx.recv().await {
                     content_len_buffer.clear();
-
-                    let message = outbound_rx.recv().await?;
                     write!(content_len_buffer, "{}", message.len()).unwrap();
                     stdin.write_all(CONTENT_LEN_HEADER.as_bytes()).await?;
                     stdin.write_all(&content_len_buffer).await?;
@@ -185,6 +185,8 @@ impl LanguageServer {
                     stdin.write_all(&message).await?;
                     stdin.flush().await?;
                 }
+                drop(output_done_tx);
+                Ok(())
             }
             .log_err(),
         );
@@ -195,14 +197,15 @@ impl LanguageServer {
             response_handlers,
             next_id: Default::default(),
             outbound_tx,
-            _input_task,
-            _output_task,
+            executor: executor.clone(),
+            io_tasks: Some((input_task, output_task)),
             initialized: initialized_rx,
+            output_done_rx: Some(output_done_rx),
         });
 
         let root_uri =
             lsp_types::Url::from_file_path(root_path).map_err(|_| anyhow!("invalid root path"))?;
-        background
+        executor
             .spawn({
                 let this = this.clone();
                 async move {
@@ -234,12 +237,18 @@ impl LanguageServer {
             locale: Default::default(),
         };
 
-        self.request_internal::<lsp_types::request::Initialize>(params)
-            .await?;
-        self.notify_internal::<lsp_types::notification::Initialized>(
-            lsp_types::InitializedParams {},
+        let this = self.clone();
+        Self::request_internal::<lsp_types::request::Initialize>(
+            &this.next_id,
+            &this.response_handlers,
+            &this.outbound_tx,
+            params,
         )
         .await?;
+        Self::notify_internal::<lsp_types::notification::Initialized>(
+            &this.outbound_tx,
+            lsp_types::InitializedParams {},
+        )?;
         Ok(())
     }
 
@@ -279,18 +288,26 @@ impl LanguageServer {
         let this = self.clone();
         async move {
             this.initialized.clone().recv().await;
-            this.request_internal::<T>(params).await
+            Self::request_internal::<T>(
+                &this.next_id,
+                &this.response_handlers,
+                &this.outbound_tx,
+                params,
+            )
+            .await
         }
     }
 
     fn request_internal<T: lsp_types::request::Request>(
-        self: &Arc<Self>,
+        next_id: &AtomicUsize,
+        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
+        outbound_tx: &channel::Sender<Vec<u8>>,
         params: T::Params,
     ) -> impl Future<Output = Result<T::Result>>
     where
         T::Result: 'static + Send,
     {
-        let id = self.next_id.fetch_add(1, SeqCst);
+        let id = next_id.fetch_add(1, SeqCst);
         let message = serde_json::to_vec(&Request {
             jsonrpc: JSON_RPC_VERSION,
             id,
@@ -298,7 +315,7 @@ impl LanguageServer {
             params,
         })
         .unwrap();
-        let mut response_handlers = self.response_handlers.lock();
+        let mut response_handlers = response_handlers.lock();
         let (mut tx, mut rx) = oneshot::channel();
         response_handlers.insert(
             id,
@@ -313,9 +330,9 @@ impl LanguageServer {
             }),
         );
 
-        let this = self.clone();
+        let send = outbound_tx.try_send(message);
         async move {
-            this.outbound_tx.send(message).await?;
+            send?;
             rx.recv().await.unwrap()
         }
     }
@@ -327,26 +344,50 @@ impl LanguageServer {
         let this = self.clone();
         async move {
             this.initialized.clone().recv().await;
-            this.notify_internal::<T>(params).await
+            Self::notify_internal::<T>(&this.outbound_tx, params)?;
+            Ok(())
         }
     }
 
     fn notify_internal<T: lsp_types::notification::Notification>(
-        self: &Arc<Self>,
+        outbound_tx: &channel::Sender<Vec<u8>>,
         params: T::Params,
-    ) -> impl Future<Output = Result<()>> {
+    ) -> Result<()> {
         let message = serde_json::to_vec(&Notification {
             jsonrpc: JSON_RPC_VERSION,
             method: T::METHOD,
             params,
         })
         .unwrap();
+        outbound_tx.try_send(message)?;
+        Ok(())
+    }
+}
 
-        let this = self.clone();
-        async move {
-            this.outbound_tx.send(message).await?;
-            Ok(())
-        }
+impl Drop for LanguageServer {
+    fn drop(&mut self) {
+        let tasks = self.io_tasks.take();
+        let response_handlers = self.response_handlers.clone();
+        let outbound_tx = self.outbound_tx.clone();
+        let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
+        let mut output_done = self.output_done_rx.take().unwrap();
+        self.executor.spawn_critical(
+            async move {
+                Self::request_internal::<lsp_types::request::Shutdown>(
+                    &next_id,
+                    &response_handlers,
+                    &outbound_tx,
+                    (),
+                )
+                .await?;
+                Self::notify_internal::<lsp_types::notification::Exit>(&outbound_tx, ())?;
+                drop(outbound_tx);
+                output_done.recv().await;
+                drop(tasks);
+                Ok(())
+            }
+            .log_err(),
+        )
     }
 }
 
@@ -377,7 +418,7 @@ pub struct RequestId<T> {
 
 #[cfg(any(test, feature = "test-support"))]
 impl LanguageServer {
-    pub async fn fake(executor: &executor::Background) -> (Arc<Self>, FakeLanguageServer) {
+    pub async fn fake(executor: Arc<executor::Background>) -> (Arc<Self>, FakeLanguageServer) {
         let stdin = async_pipe::pipe();
         let stdout = async_pipe::pipe();
         let mut fake = FakeLanguageServer {
@@ -512,8 +553,12 @@ mod tests {
             lsp_types::Url::from_file_path(root_dir.path().join("src/lib.rs")).unwrap();
 
         let server = cx.read(|cx| {
-            LanguageServer::new(Path::new("rust-analyzer"), root_dir.path(), cx.background())
-                .unwrap()
+            LanguageServer::new(
+                Path::new("rust-analyzer"),
+                root_dir.path(),
+                cx.background().clone(),
+            )
+            .unwrap()
         });
         server.next_idle_notification().await;
 
@@ -555,7 +600,7 @@ mod tests {
     async fn test_fake(cx: TestAppContext) {
         SimpleLogger::init(log::LevelFilter::Info, Default::default()).unwrap();
 
-        let (server, mut fake) = LanguageServer::fake(&cx.background()).await;
+        let (server, mut fake) = LanguageServer::fake(cx.background()).await;
 
         let (message_tx, message_rx) = channel::unbounded();
         let (diagnostics_tx, diagnostics_rx) = channel::unbounded();
@@ -606,6 +651,12 @@ mod tests {
             diagnostics_rx.recv().await.unwrap().uri.as_str(),
             "file://b/c"
         );
+
+        drop(server);
+        let (shutdown_request, _) = fake.receive_request::<lsp_types::request::Shutdown>().await;
+        fake.respond(shutdown_request, ()).await;
+        fake.receive_notification::<lsp_types::notification::Exit>()
+            .await;
     }
 
     impl LanguageServer {

crates/project/src/worktree.rs 🔗

@@ -3508,7 +3508,7 @@ mod tests {
 
     #[gpui::test]
     async fn test_language_server_diagnostics(mut cx: gpui::TestAppContext) {
-        let (language_server, mut fake_lsp) = LanguageServer::fake(&cx.background()).await;
+        let (language_server, mut fake_lsp) = LanguageServer::fake(cx.background()).await;
         let dir = temp_tree(json!({
             "a.rs": "fn a() { A }",
             "b.rs": "const y: i32 = 1",

crates/zed/src/main.rs 🔗

@@ -7,7 +7,7 @@ use gpui::AssetSource;
 use log::LevelFilter;
 use parking_lot::Mutex;
 use simplelog::SimpleLogger;
-use std::{fs, path::PathBuf, sync::Arc};
+use std::{fs, path::PathBuf, sync::Arc, time::Duration};
 use theme::ThemeRegistry;
 use workspace::{self, settings, OpenNew};
 use zed::{self, assets::Assets, fs::RealFs, language, menus, AppState, OpenParams, OpenPaths};
@@ -29,7 +29,15 @@ fn main() {
     let languages = Arc::new(language::build_language_registry());
     languages.set_theme(&settings.borrow().theme.editor.syntax);
 
-    app.run(move |cx| {
+    app.on_quit(|cx| {
+        let did_finish = cx
+            .background()
+            .block_on_critical_tasks(Duration::from_millis(100));
+        if !did_finish {
+            log::error!("timed out on quit before critical tasks finished");
+        }
+    })
+    .run(move |cx| {
         let client = client::Client::new();
         let http = http::client();
         let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http.clone(), cx));