Explicitly shut down language servers when quitting the app

Nathan Sobo , Max Brunsfeld , and Antonio Scandurra created

Co-Authored-By: Max Brunsfeld <maxbrunsfeld@gmail.com>
Co-Authored-By: Antonio Scandurra <me@as-cii.com>

Change summary

crates/gpui/src/app.rs         | 102 ++++++++++++++++++++++++++++-------
crates/gpui/src/executor.rs    |  33 -----------
crates/lsp/src/lib.rs          |  95 +++++++++++++++++++-------------
crates/project/src/worktree.rs |  20 +++++++
crates/zed/src/main.rs         |  13 ---
5 files changed, 161 insertions(+), 102 deletions(-)

Detailed changes

crates/gpui/src/app.rs 🔗

@@ -23,6 +23,7 @@ use std::{
     mem,
     ops::{Deref, DerefMut},
     path::{Path, PathBuf},
+    pin::Pin,
     rc::{self, Rc},
     sync::{
         atomic::{AtomicUsize, Ordering::SeqCst},
@@ -35,6 +36,12 @@ pub trait Entity: 'static {
     type Event;
 
     fn release(&mut self, _: &mut MutableAppContext) {}
+    fn app_will_quit(
+        &mut self,
+        _: &mut MutableAppContext,
+    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
+        None
+    }
 }
 
 pub trait View: Entity + Sized {
@@ -198,8 +205,6 @@ pub struct App(Rc<RefCell<MutableAppContext>>);
 #[derive(Clone)]
 pub struct AsyncAppContext(Rc<RefCell<MutableAppContext>>);
 
-pub struct BackgroundAppContext(*const RefCell<MutableAppContext>);
-
 #[derive(Clone)]
 pub struct TestAppContext {
     cx: Rc<RefCell<MutableAppContext>>,
@@ -220,20 +225,29 @@ impl App {
             asset_source,
         ))));
 
-        let cx = app.0.clone();
-        foreground_platform.on_menu_command(Box::new(move |action| {
-            let mut cx = cx.borrow_mut();
-            if let Some(key_window_id) = cx.cx.platform.key_window_id() {
-                if let Some((presenter, _)) = cx.presenters_and_platform_windows.get(&key_window_id)
-                {
-                    let presenter = presenter.clone();
-                    let path = presenter.borrow().dispatch_path(cx.as_ref());
-                    cx.dispatch_action_any(key_window_id, &path, action);
+        foreground_platform.on_quit(Box::new({
+            let cx = app.0.clone();
+            move || {
+                cx.borrow_mut().quit();
+            }
+        }));
+        foreground_platform.on_menu_command(Box::new({
+            let cx = app.0.clone();
+            move |action| {
+                let mut cx = cx.borrow_mut();
+                if let Some(key_window_id) = cx.cx.platform.key_window_id() {
+                    if let Some((presenter, _)) =
+                        cx.presenters_and_platform_windows.get(&key_window_id)
+                    {
+                        let presenter = presenter.clone();
+                        let path = presenter.borrow().dispatch_path(cx.as_ref());
+                        cx.dispatch_action_any(key_window_id, &path, action);
+                    } else {
+                        cx.dispatch_global_action_any(action);
+                    }
                 } else {
                     cx.dispatch_global_action_any(action);
                 }
-            } else {
-                cx.dispatch_global_action_any(action);
             }
         }));
 
@@ -751,6 +765,39 @@ impl MutableAppContext {
         App(self.weak_self.as_ref().unwrap().upgrade().unwrap())
     }
 
+    pub fn quit(&mut self) {
+        let mut futures = Vec::new();
+        for model_id in self.cx.models.keys().copied().collect::<Vec<_>>() {
+            let mut model = self.cx.models.remove(&model_id).unwrap();
+            futures.extend(model.app_will_quit(self));
+            self.cx.models.insert(model_id, model);
+        }
+
+        for view_id in self.cx.views.keys().copied().collect::<Vec<_>>() {
+            let mut view = self.cx.views.remove(&view_id).unwrap();
+            futures.extend(view.app_will_quit(self));
+            self.cx.views.insert(view_id, view);
+        }
+
+        self.remove_all_windows();
+
+        let futures = futures::future::join_all(futures);
+        if self
+            .background
+            .block_with_timeout(Duration::from_millis(100), futures)
+            .is_err()
+        {
+            log::error!("timed out waiting on app_will_quit");
+        }
+    }
+
+    fn remove_all_windows(&mut self) {
+        for (window_id, _) in self.cx.windows.drain() {
+            self.presenters_and_platform_windows.remove(&window_id);
+        }
+        self.remove_dropped_entities();
+    }
+
     pub fn platform(&self) -> Arc<dyn platform::Platform> {
         self.cx.platform.clone()
     }
@@ -1230,13 +1277,6 @@ impl MutableAppContext {
         self.remove_dropped_entities();
     }
 
-    pub fn remove_all_windows(&mut self) {
-        for (window_id, _) in self.cx.windows.drain() {
-            self.presenters_and_platform_windows.remove(&window_id);
-        }
-        self.remove_dropped_entities();
-    }
-
     fn open_platform_window(&mut self, window_id: usize, window_options: WindowOptions) {
         let mut window =
             self.cx
@@ -1898,6 +1938,10 @@ pub trait AnyModel {
     fn as_any(&self) -> &dyn Any;
     fn as_any_mut(&mut self) -> &mut dyn Any;
     fn release(&mut self, cx: &mut MutableAppContext);
+    fn app_will_quit(
+        &mut self,
+        cx: &mut MutableAppContext,
+    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>>;
 }
 
 impl<T> AnyModel for T
@@ -1915,12 +1959,23 @@ where
     fn release(&mut self, cx: &mut MutableAppContext) {
         self.release(cx);
     }
+
+    fn app_will_quit(
+        &mut self,
+        cx: &mut MutableAppContext,
+    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
+        self.app_will_quit(cx)
+    }
 }
 
 pub trait AnyView {
     fn as_any(&self) -> &dyn Any;
     fn as_any_mut(&mut self) -> &mut dyn Any;
     fn release(&mut self, cx: &mut MutableAppContext);
+    fn app_will_quit(
+        &mut self,
+        cx: &mut MutableAppContext,
+    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>>;
     fn ui_name(&self) -> &'static str;
     fn render<'a>(
         &mut self,
@@ -1951,6 +2006,13 @@ where
         self.release(cx);
     }
 
+    fn app_will_quit(
+        &mut self,
+        cx: &mut MutableAppContext,
+    ) -> Option<Pin<Box<dyn 'static + Future<Output = ()>>>> {
+        self.app_will_quit(cx)
+    }
+
     fn ui_name(&self) -> &'static str {
         T::ui_name()
     }

crates/gpui/src/executor.rs 🔗

@@ -40,11 +40,9 @@ pub enum Foreground {
 pub enum Background {
     Deterministic {
         executor: Arc<Deterministic>,
-        critical_tasks: Mutex<Vec<Task<()>>>,
     },
     Production {
         executor: Arc<smol::Executor<'static>>,
-        critical_tasks: Mutex<Vec<Task<()>>>,
         _stop: channel::Sender<()>,
     },
 }
@@ -504,7 +502,6 @@ impl Background {
 
         Self::Production {
             executor,
-            critical_tasks: Default::default(),
             _stop: stop.0,
         }
     }
@@ -526,31 +523,6 @@ impl Background {
         Task::send(any_task)
     }
 
-    pub fn spawn_critical<T, F>(&self, future: F)
-    where
-        T: 'static + Send,
-        F: Send + Future<Output = T> + 'static,
-    {
-        let task = self.spawn(async move {
-            future.await;
-        });
-        match self {
-            Self::Production { critical_tasks, .. }
-            | Self::Deterministic { critical_tasks, .. } => critical_tasks.lock().push(task),
-        }
-    }
-
-    pub fn block_on_critical_tasks(&self, timeout: Duration) -> bool {
-        match self {
-            Background::Production { critical_tasks, .. }
-            | Self::Deterministic { critical_tasks, .. } => {
-                let tasks = mem::take(&mut *critical_tasks.lock());
-                self.block_with_timeout(timeout, futures::future::join_all(tasks))
-                    .is_ok()
-            }
-        }
-    }
-
     pub fn block_with_timeout<F, T>(
         &self,
         timeout: Duration,
@@ -617,10 +589,7 @@ pub fn deterministic(seed: u64) -> (Rc<Foreground>, Arc<Background>) {
     let executor = Arc::new(Deterministic::new(seed));
     (
         Rc::new(Foreground::Deterministic(executor.clone())),
-        Arc::new(Background::Deterministic {
-            executor,
-            critical_tasks: Default::default(),
-        }),
+        Arc::new(Background::Deterministic { executor }),
     )
 }
 

crates/lsp/src/lib.rs 🔗

@@ -33,13 +33,13 @@ type ResponseHandler = Box<dyn Send + FnOnce(Result<&str, Error>)>;
 
 pub struct LanguageServer {
     next_id: AtomicUsize,
-    outbound_tx: channel::Sender<Vec<u8>>,
+    outbound_tx: RwLock<Option<channel::Sender<Vec<u8>>>>,
     notification_handlers: Arc<RwLock<HashMap<&'static str, NotificationHandler>>>,
     response_handlers: Arc<Mutex<HashMap<usize, ResponseHandler>>>,
     executor: Arc<executor::Background>,
-    io_tasks: Option<(Task<Option<()>>, Task<Option<()>>)>,
+    io_tasks: Mutex<Option<(Task<Option<()>>, Task<Option<()>>)>>,
     initialized: barrier::Receiver,
-    output_done_rx: Option<barrier::Receiver>,
+    output_done_rx: Mutex<Option<barrier::Receiver>>,
 }
 
 pub struct Subscription {
@@ -198,11 +198,11 @@ impl LanguageServer {
             notification_handlers,
             response_handlers,
             next_id: Default::default(),
-            outbound_tx,
+            outbound_tx: RwLock::new(Some(outbound_tx)),
             executor: executor.clone(),
-            io_tasks: Some((input_task, output_task)),
+            io_tasks: Mutex::new(Some((input_task, output_task))),
             initialized: initialized_rx,
-            output_done_rx: Some(output_done_rx),
+            output_done_rx: Mutex::new(Some(output_done_rx)),
         });
 
         let root_uri =
@@ -240,20 +240,45 @@ impl LanguageServer {
         };
 
         let this = self.clone();
-        Self::request_internal::<lsp_types::request::Initialize>(
+        let request = Self::request_internal::<lsp_types::request::Initialize>(
             &this.next_id,
             &this.response_handlers,
-            &this.outbound_tx,
+            this.outbound_tx.read().as_ref(),
             params,
-        )
-        .await?;
+        );
+        request.await?;
         Self::notify_internal::<lsp_types::notification::Initialized>(
-            &this.outbound_tx,
+            this.outbound_tx.read().as_ref(),
             lsp_types::InitializedParams {},
         )?;
         Ok(())
     }
 
+    pub fn shutdown(&self) -> Option<impl 'static + Send + Future<Output = Result<()>>> {
+        if let Some(tasks) = self.io_tasks.lock().take() {
+            let response_handlers = self.response_handlers.clone();
+            let outbound_tx = self.outbound_tx.write().take();
+            let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
+            let mut output_done = self.output_done_rx.lock().take().unwrap();
+            Some(async move {
+                Self::request_internal::<lsp_types::request::Shutdown>(
+                    &next_id,
+                    &response_handlers,
+                    outbound_tx.as_ref(),
+                    (),
+                )
+                .await?;
+                Self::notify_internal::<lsp_types::notification::Exit>(outbound_tx.as_ref(), ())?;
+                drop(outbound_tx);
+                output_done.recv().await;
+                drop(tasks);
+                Ok(())
+            })
+        } else {
+            None
+        }
+    }
+
     pub fn on_notification<T, F>(&self, f: F) -> Subscription
     where
         T: lsp_types::notification::Notification,
@@ -293,7 +318,7 @@ impl LanguageServer {
             Self::request_internal::<T>(
                 &this.next_id,
                 &this.response_handlers,
-                &this.outbound_tx,
+                this.outbound_tx.read().as_ref(),
                 params,
             )
             .await
@@ -303,9 +328,9 @@ impl LanguageServer {
     fn request_internal<T: lsp_types::request::Request>(
         next_id: &AtomicUsize,
         response_handlers: &Mutex<HashMap<usize, ResponseHandler>>,
-        outbound_tx: &channel::Sender<Vec<u8>>,
+        outbound_tx: Option<&channel::Sender<Vec<u8>>>,
         params: T::Params,
-    ) -> impl Future<Output = Result<T::Result>>
+    ) -> impl 'static + Future<Output = Result<T::Result>>
     where
         T::Result: 'static + Send,
     {
@@ -332,7 +357,15 @@ impl LanguageServer {
             }),
         );
 
-        let send = outbound_tx.try_send(message);
+        let send = outbound_tx
+            .as_ref()
+            .ok_or_else(|| {
+                anyhow!("tried to send a request to a language server that has been shut down")
+            })
+            .and_then(|outbound_tx| {
+                outbound_tx.try_send(message)?;
+                Ok(())
+            });
         async move {
             send?;
             rx.recv().await.unwrap()
@@ -346,13 +379,13 @@ impl LanguageServer {
         let this = self.clone();
         async move {
             this.initialized.clone().recv().await;
-            Self::notify_internal::<T>(&this.outbound_tx, params)?;
+            Self::notify_internal::<T>(this.outbound_tx.read().as_ref(), params)?;
             Ok(())
         }
     }
 
     fn notify_internal<T: lsp_types::notification::Notification>(
-        outbound_tx: &channel::Sender<Vec<u8>>,
+        outbound_tx: Option<&channel::Sender<Vec<u8>>>,
         params: T::Params,
     ) -> Result<()> {
         let message = serde_json::to_vec(&Notification {
@@ -361,6 +394,9 @@ impl LanguageServer {
             params,
         })
         .unwrap();
+        let outbound_tx = outbound_tx
+            .as_ref()
+            .ok_or_else(|| anyhow!("tried to notify a language server that has been shut down"))?;
         outbound_tx.try_send(message)?;
         Ok(())
     }
@@ -368,28 +404,9 @@ impl LanguageServer {
 
 impl Drop for LanguageServer {
     fn drop(&mut self) {
-        let tasks = self.io_tasks.take();
-        let response_handlers = self.response_handlers.clone();
-        let outbound_tx = self.outbound_tx.clone();
-        let next_id = AtomicUsize::new(self.next_id.load(SeqCst));
-        let mut output_done = self.output_done_rx.take().unwrap();
-        self.executor.spawn_critical(
-            async move {
-                Self::request_internal::<lsp_types::request::Shutdown>(
-                    &next_id,
-                    &response_handlers,
-                    &outbound_tx,
-                    (),
-                )
-                .await?;
-                Self::notify_internal::<lsp_types::notification::Exit>(&outbound_tx, ())?;
-                drop(outbound_tx);
-                output_done.recv().await;
-                drop(tasks);
-                Ok(())
-            }
-            .log_err(),
-        )
+        if let Some(shutdown) = self.shutdown() {
+            self.executor.spawn(shutdown).detach();
+        }
     }
 }
 

crates/project/src/worktree.rs 🔗

@@ -20,6 +20,7 @@ use postage::{
     prelude::{Sink as _, Stream as _},
     watch,
 };
+
 use serde::Deserialize;
 use smol::channel::{self, Sender};
 use std::{
@@ -90,6 +91,25 @@ impl Entity for Worktree {
             }
         }
     }
+
+    fn app_will_quit(
+        &mut self,
+        _: &mut MutableAppContext,
+    ) -> Option<std::pin::Pin<Box<dyn 'static + Future<Output = ()>>>> {
+        use futures::FutureExt;
+
+        if let Some(server) = self.language_server() {
+            if let Some(shutdown) = server.shutdown() {
+                return Some(
+                    async move {
+                        shutdown.await.log_err();
+                    }
+                    .boxed(),
+                );
+            }
+        }
+        None
+    }
 }
 
 impl Worktree {

crates/zed/src/main.rs 🔗

@@ -7,7 +7,7 @@ use gpui::AssetSource;
 use log::LevelFilter;
 use parking_lot::Mutex;
 use simplelog::SimpleLogger;
-use std::{fs, path::PathBuf, sync::Arc, time::Duration};
+use std::{fs, path::PathBuf, sync::Arc};
 use theme::ThemeRegistry;
 use workspace::{self, settings, OpenNew};
 use zed::{self, assets::Assets, fs::RealFs, language, menus, AppState, OpenParams, OpenPaths};
@@ -29,16 +29,7 @@ fn main() {
     let languages = Arc::new(language::build_language_registry());
     languages.set_theme(&settings.borrow().theme.editor.syntax);
 
-    app.on_quit(|cx| {
-        cx.remove_all_windows();
-        let did_finish = cx
-            .background()
-            .block_on_critical_tasks(Duration::from_millis(100));
-        if !did_finish {
-            log::error!("timed out on quit before critical tasks finished");
-        }
-    })
-    .run(move |cx| {
+    app.run(move |cx| {
         let client = client::Client::new();
         let http = http::client();
         let user_store = cx.add_model(|cx| UserStore::new(client.clone(), http.clone(), cx));