Prevent making further requests after language server shut down

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/lsp/src/lsp.rs | 71 ++++++++++++++++++++++++++------------------
1 file changed, 42 insertions(+), 29 deletions(-)

Detailed changes

crates/lsp/src/lsp.rs 🔗

@@ -40,7 +40,7 @@ pub struct LanguageServer {
     name: String,
     capabilities: ServerCapabilities,
     notification_handlers: Arc<Mutex<HashMap<&'static str, NotificationHandler>>>,
-    response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
+    response_handlers: Arc<Mutex<Option<HashMap<usize, ResponseHandler>>>>,
     executor: Arc<executor::Background>,
     #[allow(clippy::type_complexity)]
     io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
@@ -170,12 +170,18 @@ impl LanguageServer {
         let (outbound_tx, outbound_rx) = channel::unbounded::<Vec<u8>>();
         let notification_handlers =
             Arc::new(Mutex::new(HashMap::<_, NotificationHandler>::default()));
-        let response_handlers = Arc::new(Mutex::new(HashMap::<_, ResponseHandler>::default()));
+        let response_handlers =
+            Arc::new(Mutex::new(Some(HashMap::<_, ResponseHandler>::default())));
         let input_task = cx.spawn(|cx| {
             let notification_handlers = notification_handlers.clone();
             let response_handlers = response_handlers.clone();
             async move {
-                let _clear_response_handlers = ClearResponseHandlers(response_handlers.clone());
+                let _clear_response_handlers = util::defer({
+                    let response_handlers = response_handlers.clone();
+                    move || {
+                        response_handlers.lock().take();
+                    }
+                });
                 let mut buffer = Vec::new();
                 loop {
                     buffer.clear();
@@ -200,7 +206,11 @@ impl LanguageServer {
                     } else if let Ok(AnyResponse { id, error, result }) =
                         serde_json::from_slice(&buffer)
                     {
-                        if let Some(handler) = response_handlers.lock().remove(&id) {
+                        if let Some(handler) = response_handlers
+                            .lock()
+                            .as_mut()
+                            .and_then(|handlers| handlers.remove(&id))
+                        {
                             if let Some(error) = error {
                                 handler(Err(error));
                             } else if let Some(result) = result {
@@ -226,7 +236,12 @@ impl LanguageServer {
         let output_task = cx.background().spawn({
             let response_handlers = response_handlers.clone();
             async move {
-                let _clear_response_handlers = ClearResponseHandlers(response_handlers);
+                let _clear_response_handlers = util::defer({
+                    let response_handlers = response_handlers.clone();
+                    move || {
+                        response_handlers.lock().take();
+                    }
+                });
                 let mut content_len_buffer = Vec::new();
                 while let Ok(message) = outbound_rx.recv().await {
                     log::trace!("outgoing message:{}", String::from_utf8_lossy(&message));
@@ -366,7 +381,7 @@ impl LanguageServer {
                 async move {
                     log::debug!("language server shutdown started");
                     shutdown_request.await?;
-                    response_handlers.lock().clear();
+                    response_handlers.lock().take();
                     exit?;
                     output_done.recv().await;
                     log::debug!("language server shutdown finished");
@@ -521,7 +536,7 @@ impl LanguageServer {
 
     fn request_internal<T: request::Request>(
         next_id: &AtomicUsize,
-        response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
+        response_handlers: &Mutex<Option<HashMap<usize, ResponseHandler>>>,
         outbound_tx: &channel::Sender<Vec<u8>>,
         params: T::Params,
     ) -> impl 'static + Future<Output = Result<T::Result>>
@@ -537,25 +552,31 @@ impl LanguageServer {
         })
         .unwrap();
 
+        let (tx, rx) = oneshot::channel();
+        let handle_response = response_handlers
+            .lock()
+            .as_mut()
+            .ok_or_else(|| anyhow!("server shut down"))
+            .map(|handlers| {
+                handlers.insert(
+                    id,
+                    Box::new(move |result| {
+                        let response = match result {
+                            Ok(response) => serde_json::from_str(response)
+                                .context("failed to deserialize response"),
+                            Err(error) => Err(anyhow!("{}", error.message)),
+                        };
+                        let _ = tx.send(response);
+                    }),
+                );
+            });
+
         let send = outbound_tx
             .try_send(message)
             .context("failed to write to language server's stdin");
 
-        let (tx, rx) = oneshot::channel();
-        response_handlers.lock().insert(
-            id,
-            Box::new(move |result| {
-                let response = match result {
-                    Ok(response) => {
-                        serde_json::from_str(response).context("failed to deserialize response")
-                    }
-                    Err(error) => Err(anyhow!("{}", error.message)),
-                };
-                let _ = tx.send(response);
-            }),
-        );
-
         async move {
+            handle_response?;
             send?;
             rx.await?
         }
@@ -762,14 +783,6 @@ impl FakeLanguageServer {
     }
 }
 
-struct ClearResponseHandlers(Arc<Mutex<HashMap<usize, ResponseHandler>>>);
-
-impl Drop for ClearResponseHandlers {
-    fn drop(&mut self) {
-        self.0.lock().clear();
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;