Merge pull request #763 from zed-industries/inconsistent-diagnostic-state

Created by Antonio Scandurra

Fix bad diagnostic state when restarting a language server w/ a running diagnostic task
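
Previously the project tracked disk-based diagnostics with a single
`language_servers_with_diagnostics_running` counter. As the new regression test
below shows, a server restarted mid-task never reports that its diagnostics
finished, so a stale increment could leave the project stuck in a "running
diagnostics" state. The hunks below remove the counter and instead derive the
state from the per-server `pending_diagnostic_updates` counts stored in
`language_server_statuses`; gpui's test helpers also switch from postage to
futures channels so the new test can await project events.

A minimal sketch of the idea. The field `pending_diagnostic_updates` and the
body of `is_running_disk_based_diagnostics` come from the diff; the surrounding
types and key are illustrative assumptions, not Zed's actual shapes:

use std::collections::HashMap;

struct LanguageServerStatus {
    // Per-server count of in-flight disk-based diagnostic updates.
    pending_diagnostic_updates: isize,
}

struct Project {
    // Keyed by an illustrative server id; the real map lives in project.rs.
    language_server_statuses: HashMap<usize, LanguageServerStatus>,
}

impl Project {
    // Mirrors the new `is_running_disk_based_diagnostics` in the diff: the
    // answer is computed from live per-server state rather than a counter.
    fn is_running_disk_based_diagnostics(&self) -> bool {
        self.language_server_statuses
            .values()
            .any(|status| status.pending_diagnostic_updates > 0)
    }
}

fn main() {
    let mut project = Project { language_server_statuses: HashMap::new() };
    project
        .language_server_statuses
        .insert(0, LanguageServerStatus { pending_diagnostic_updates: 1 });
    assert!(project.is_running_disk_based_diagnostics());

    // In this sketch, removing a status entry removes its pending count with
    // it; the aggregate is always computed from whatever statuses exist now.
    project.language_server_statuses.remove(&0);
    assert!(!project.is_running_disk_based_diagnostics());
}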

Change summary

crates/diagnostics/src/items.rs |  2 +-
crates/gpui/src/app.rs          | 31 ++++------
crates/project/src/project.rs   | 99 ++++++++++++++++++++++++++++++++--
3 files changed, 105 insertions(+), 27 deletions(-)

Detailed changes

crates/diagnostics/src/items.rs

@@ -3,8 +3,8 @@ use gpui::{
     elements::*, platform::CursorStyle, Entity, ModelHandle, RenderContext, View, ViewContext,
 };
 use project::Project;
-use workspace::{StatusItemView};
 use settings::Settings;
+use workspace::StatusItemView;
 
 pub struct DiagnosticSummary {
     summary: project::DiagnosticSummary,

crates/gpui/src/app.rs

@@ -3392,12 +3392,10 @@ impl<T: Entity> ModelHandle<T> {
 
     #[cfg(any(test, feature = "test-support"))]
     pub fn next_notification(&self, cx: &TestAppContext) -> impl Future<Output = ()> {
-        use postage::prelude::{Sink as _, Stream as _};
-
-        let (mut tx, mut rx) = postage::mpsc::channel(1);
+        let (tx, mut rx) = futures::channel::mpsc::unbounded();
         let mut cx = cx.cx.borrow_mut();
         let subscription = cx.observe(self, move |_, _| {
-            tx.try_send(()).ok();
+            tx.unbounded_send(()).ok();
         });
 
         let duration = if std::env::var("CI").is_ok() {
@@ -3407,7 +3405,7 @@ impl<T: Entity> ModelHandle<T> {
         };
 
         async move {
-            let notification = crate::util::timeout(duration, rx.recv())
+            let notification = crate::util::timeout(duration, rx.next())
                 .await
                 .expect("next notification timed out");
             drop(subscription);
@@ -3420,12 +3418,10 @@ impl<T: Entity> ModelHandle<T> {
     where
         T::Event: Clone,
     {
-        use postage::prelude::{Sink as _, Stream as _};
-
-        let (mut tx, mut rx) = postage::mpsc::channel(1);
+        let (tx, mut rx) = futures::channel::mpsc::unbounded();
         let mut cx = cx.cx.borrow_mut();
         let subscription = cx.subscribe(self, move |_, event, _| {
-            tx.blocking_send(event.clone()).ok();
+            tx.unbounded_send(event.clone()).ok();
         });
 
         let duration = if std::env::var("CI").is_ok() {
@@ -3434,8 +3430,9 @@ impl<T: Entity> ModelHandle<T> {
             Duration::from_secs(1)
         };
 
+        cx.foreground.start_waiting();
         async move {
-            let event = crate::util::timeout(duration, rx.recv())
+            let event = crate::util::timeout(duration, rx.next())
                 .await
                 .expect("next event timed out");
             drop(subscription);
@@ -3449,22 +3446,20 @@ impl<T: Entity> ModelHandle<T> {
         cx: &TestAppContext,
         mut predicate: impl FnMut(&T, &AppContext) -> bool,
     ) -> impl Future<Output = ()> {
-        use postage::prelude::{Sink as _, Stream as _};
-
-        let (tx, mut rx) = postage::mpsc::channel(1024);
+        let (tx, mut rx) = futures::channel::mpsc::unbounded();
 
         let mut cx = cx.cx.borrow_mut();
         let subscriptions = (
             cx.observe(self, {
-                let mut tx = tx.clone();
+                let tx = tx.clone();
                 move |_, _| {
-                    tx.blocking_send(()).ok();
+                    tx.unbounded_send(()).ok();
                 }
             }),
             cx.subscribe(self, {
-                let mut tx = tx.clone();
+                let tx = tx.clone();
                 move |_, _, _| {
-                    tx.blocking_send(()).ok();
+                    tx.unbounded_send(()).ok();
                 }
             }),
         );
@@ -3495,7 +3490,7 @@ impl<T: Entity> ModelHandle<T> {
                     }
 
                     cx.borrow().foreground().start_waiting();
-                    rx.recv()
+                    rx.next()
                         .await
                         .expect("model dropped with pending condition");
                     cx.borrow().foreground().finish_waiting();
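
The gpui hunks above swap postage's bounded channels for futures unbounded
channels in the test helpers, so sending from a synchronous observer callback
can never block the single-threaded test executor. A minimal, self-contained
sketch of that pattern using plain futures (the callback is a stand-in, not
gpui's API):

use futures::{channel::mpsc, StreamExt as _};

fn main() {
    // An unbounded channel: the sender side never blocks, unlike postage's
    // bounded `blocking_send`/`try_send` used before this change.
    let (tx, mut rx) = mpsc::unbounded::<()>();

    // Stand-in for the observe/subscribe callbacks in the diff above.
    let notify = move || {
        tx.unbounded_send(()).ok();
    };
    notify();

    // The awaiting side uses `StreamExt::next` in place of postage's `recv`.
    futures::executor::block_on(async move {
        rx.next().await.expect("sender dropped before notifying");
    });
}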

crates/project/src/project.rs

@@ -74,7 +74,6 @@ pub struct Project {
     client_state: ProjectClientState,
     collaborators: HashMap<PeerId, Collaborator>,
     subscriptions: Vec<client::Subscription>,
-    language_servers_with_diagnostics_running: isize,
     opened_buffer: (Rc<RefCell<watch::Sender<()>>>, watch::Receiver<()>),
     shared_buffers: HashMap<PeerId, HashSet<u64>>,
     loading_buffers: HashMap<
@@ -330,7 +329,6 @@ impl Project {
                 user_store,
                 fs,
                 next_entry_id: Default::default(),
-                language_servers_with_diagnostics_running: 0,
                 language_servers: Default::default(),
                 started_language_servers: Default::default(),
                 language_server_statuses: Default::default(),
@@ -404,7 +402,6 @@ impl Project {
                         .log_err()
                     }),
                 },
-                language_servers_with_diagnostics_running: 0,
                 language_servers: Default::default(),
                 started_language_servers: Default::default(),
                 language_server_settings: Default::default(),
@@ -3496,7 +3493,9 @@ impl Project {
     }
 
     pub fn is_running_disk_based_diagnostics(&self) -> bool {
-        self.language_servers_with_diagnostics_running > 0
+        self.language_server_statuses
+            .values()
+            .any(|status| status.pending_diagnostic_updates > 0)
     }
 
     pub fn diagnostic_summary(&self, cx: &AppContext) -> DiagnosticSummary {
@@ -3524,16 +3523,26 @@ impl Project {
     }
 
     pub fn disk_based_diagnostics_started(&mut self, cx: &mut ModelContext<Self>) {
-        self.language_servers_with_diagnostics_running += 1;
-        if self.language_servers_with_diagnostics_running == 1 {
+        if self
+            .language_server_statuses
+            .values()
+            .map(|status| status.pending_diagnostic_updates)
+            .sum::<isize>()
+            == 1
+        {
             cx.emit(Event::DiskBasedDiagnosticsStarted);
         }
     }
 
     pub fn disk_based_diagnostics_finished(&mut self, cx: &mut ModelContext<Self>) {
         cx.emit(Event::DiskBasedDiagnosticsUpdated);
-        self.language_servers_with_diagnostics_running -= 1;
-        if self.language_servers_with_diagnostics_running == 0 {
+        if self
+            .language_server_statuses
+            .values()
+            .map(|status| status.pending_diagnostic_updates)
+            .sum::<isize>()
+            == 0
+        {
             cx.emit(Event::DiskBasedDiagnosticsFinished);
         }
     }
@@ -5453,6 +5462,80 @@ mod tests {
         });
     }
 
+    #[gpui::test]
+    async fn test_restarting_server_with_diagnostics_running(cx: &mut gpui::TestAppContext) {
+        cx.foreground().forbid_parking();
+
+        let progress_token = "the-progress-token";
+        let mut language = Language::new(
+            LanguageConfig {
+                path_suffixes: vec!["rs".to_string()],
+                ..Default::default()
+            },
+            None,
+        );
+        let mut fake_servers = language.set_fake_lsp_adapter(FakeLspAdapter {
+            disk_based_diagnostics_sources: &["disk"],
+            disk_based_diagnostics_progress_token: Some(progress_token),
+            ..Default::default()
+        });
+
+        let fs = FakeFs::new(cx.background());
+        fs.insert_tree("/dir", json!({ "a.rs": "" })).await;
+
+        let project = Project::test(fs, cx);
+        project.update(cx, |project, _| project.languages.add(Arc::new(language)));
+
+        let worktree_id = project
+            .update(cx, |project, cx| {
+                project.find_or_create_local_worktree("/dir", true, cx)
+            })
+            .await
+            .unwrap()
+            .0
+            .read_with(cx, |tree, _| tree.id());
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                project.open_buffer((worktree_id, "a.rs"), cx)
+            })
+            .await
+            .unwrap();
+
+        // Simulate diagnostics starting to update.
+        let mut fake_server = fake_servers.next().await.unwrap();
+        fake_server.start_progress(progress_token).await;
+
+        // Restart the server before the diagnostics finish updating.
+        project.update(cx, |project, cx| {
+            project.restart_language_servers_for_buffers([buffer], cx);
+        });
+        let mut events = subscribe(&project, cx);
+
+        // Simulate the newly started server sending more diagnostics.
+        let mut fake_server = fake_servers.next().await.unwrap();
+        fake_server.start_progress(progress_token).await;
+        assert_eq!(
+            events.next().await.unwrap(),
+            Event::DiskBasedDiagnosticsStarted
+        );
+
+        // All diagnostics are considered done, despite the old server's diagnostic
+        // task never completing.
+        fake_server.end_progress(progress_token).await;
+        assert_eq!(
+            events.next().await.unwrap(),
+            Event::DiskBasedDiagnosticsUpdated
+        );
+        assert_eq!(
+            events.next().await.unwrap(),
+            Event::DiskBasedDiagnosticsFinished
+        );
+        project.read_with(cx, |project, _| {
+            assert!(!project.is_running_disk_based_diagnostics());
+        });
+    }
+
     #[gpui::test]
     async fn test_transforming_diagnostics(cx: &mut gpui::TestAppContext) {
         cx.foreground().forbid_parking();