Merge pull request #1247 from zed-industries/ignore-non-created-progress-tokens

Antonio Scandurra created

Ignore tokens that were not created via `WorkDoneProgressCreate`

Change summary

crates/collab/src/integration_tests.rs |  1 
crates/lsp/src/lsp.rs                  | 20 +++++
crates/project/src/project.rs          | 91 ++++++++++++++++-----------
3 files changed, 72 insertions(+), 40 deletions(-)

Detailed changes

crates/collab/src/integration_tests.rs 🔗

@@ -2995,6 +2995,7 @@ async fn test_language_server_statuses(
         .unwrap();
 
     let fake_language_server = fake_language_servers.next().await.unwrap();
+    fake_language_server.start_progress("the-token").await;
     fake_language_server.notify::<lsp::notification::Progress>(lsp::ProgressParams {
         token: lsp::NumberOrString::String("the-token".to_string()),
         value: lsp::ProgressParamsValue::WorkDone(lsp::WorkDoneProgress::Report(

crates/lsp/src/lsp.rs 🔗

@@ -655,6 +655,14 @@ impl FakeLanguageServer {
         self.server.notify::<T>(params).ok();
     }
 
+    pub async fn request<T>(&self, params: T::Params) -> Result<T::Result>
+    where
+        T: request::Request,
+        T::Result: 'static + Send,
+    {
+        self.server.request::<T>(params).await
+    }
+
     pub async fn receive_notification<T: notification::Notification>(&mut self) -> T::Params {
         self.try_receive_notification::<T>().await.unwrap()
     }
@@ -708,14 +716,20 @@ impl FakeLanguageServer {
         self.server.remove_request_handler::<T>();
     }
 
-    pub async fn start_progress(&mut self, token: impl Into<String>) {
+    pub async fn start_progress(&self, token: impl Into<String>) {
+        let token = token.into();
+        self.request::<request::WorkDoneProgressCreate>(WorkDoneProgressCreateParams {
+            token: NumberOrString::String(token.clone()),
+        })
+        .await
+        .unwrap();
         self.notify::<notification::Progress>(ProgressParams {
-            token: NumberOrString::String(token.into()),
+            token: NumberOrString::String(token),
             value: ProgressParamsValue::WorkDone(WorkDoneProgress::Begin(Default::default())),
         });
     }
 
-    pub async fn end_progress(&mut self, token: impl Into<String>) {
+    pub fn end_progress(&self, token: impl Into<String>) {
         self.notify::<notification::Progress>(ProgressParams {
             token: NumberOrString::String(token.into()),
             value: ProgressParamsValue::WorkDone(WorkDoneProgress::End(Default::default())),

crates/project/src/project.rs 🔗

@@ -178,7 +178,8 @@ pub enum Event {
 pub struct LanguageServerStatus {
     pub name: String,
     pub pending_work: BTreeMap<String, LanguageServerProgress>,
-    pub pending_diagnostic_updates: isize,
+    pub has_pending_diagnostic_updates: bool,
+    progress_tokens: HashSet<String>,
 }
 
 #[derive(Clone, Debug, Serialize)]
@@ -546,7 +547,8 @@ impl Project {
                             LanguageServerStatus {
                                 name: server.name,
                                 pending_work: Default::default(),
-                                pending_diagnostic_updates: 0,
+                                has_pending_diagnostic_updates: false,
+                                progress_tokens: Default::default(),
                             },
                         )
                     })
@@ -2025,8 +2027,23 @@ impl Project {
                     // avoid stalling any language server like `gopls` which waits for a response
                     // to these requests when initializing.
                     language_server
-                        .on_request::<lsp::request::WorkDoneProgressCreate, _, _>(|_, _| async {
-                            Ok(())
+                        .on_request::<lsp::request::WorkDoneProgressCreate, _, _>({
+                            let this = this.downgrade();
+                            move |params, mut cx| async move {
+                                if let Some(this) = this.upgrade(&cx) {
+                                    this.update(&mut cx, |this, _| {
+                                        if let Some(status) =
+                                            this.language_server_statuses.get_mut(&server_id)
+                                        {
+                                            if let lsp::NumberOrString::String(token) = params.token
+                                            {
+                                                status.progress_tokens.insert(token);
+                                            }
+                                        }
+                                    });
+                                }
+                                Ok(())
+                            }
                         })
                         .detach();
                     language_server
@@ -2079,7 +2096,8 @@ impl Project {
                             LanguageServerStatus {
                                 name: language_server.name().to_string(),
                                 pending_work: Default::default(),
-                                pending_diagnostic_updates: 0,
+                                has_pending_diagnostic_updates: false,
+                                progress_tokens: Default::default(),
                             },
                         );
                         language_server
@@ -2291,19 +2309,22 @@ impl Project {
             } else {
                 return;
             };
+
+        if !language_server_status.progress_tokens.contains(&token) {
+            return;
+        }
+
         match progress {
             lsp::WorkDoneProgress::Begin(report) => {
                 if Some(token.as_str()) == disk_based_diagnostics_progress_token {
-                    language_server_status.pending_diagnostic_updates += 1;
-                    if language_server_status.pending_diagnostic_updates == 1 {
-                        self.disk_based_diagnostics_started(server_id, cx);
-                        self.broadcast_language_server_update(
-                            server_id,
-                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdating(
-                                proto::LspDiskBasedDiagnosticsUpdating {},
-                            ),
-                        );
-                    }
+                    language_server_status.has_pending_diagnostic_updates = true;
+                    self.disk_based_diagnostics_started(server_id, cx);
+                    self.broadcast_language_server_update(
+                        server_id,
+                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdating(
+                            proto::LspDiskBasedDiagnosticsUpdating {},
+                        ),
+                    );
                 } else {
                     self.on_lsp_work_start(
                         server_id,
@@ -2350,17 +2371,17 @@ impl Project {
                 }
             }
             lsp::WorkDoneProgress::End(_) => {
+                language_server_status.progress_tokens.remove(&token);
+
                 if Some(token.as_str()) == disk_based_diagnostics_progress_token {
-                    language_server_status.pending_diagnostic_updates -= 1;
-                    if language_server_status.pending_diagnostic_updates == 0 {
-                        self.disk_based_diagnostics_finished(server_id, cx);
-                        self.broadcast_language_server_update(
-                            server_id,
-                            proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
-                                proto::LspDiskBasedDiagnosticsUpdated {},
-                            ),
-                        );
-                    }
+                    language_server_status.has_pending_diagnostic_updates = false;
+                    self.disk_based_diagnostics_finished(server_id, cx);
+                    self.broadcast_language_server_update(
+                        server_id,
+                        proto::update_language_server::Variant::DiskBasedDiagnosticsUpdated(
+                            proto::LspDiskBasedDiagnosticsUpdated {},
+                        ),
+                    );
                 } else {
                     self.on_lsp_work_end(server_id, token.clone(), cx);
                     self.broadcast_language_server_update(
@@ -4216,7 +4237,7 @@ impl Project {
         self.language_server_statuses
             .iter()
             .filter_map(|(id, status)| {
-                if status.pending_diagnostic_updates > 0 {
+                if status.has_pending_diagnostic_updates {
                     Some(*id)
                 } else {
                     None
@@ -4616,7 +4637,8 @@ impl Project {
                 LanguageServerStatus {
                     name: server.name,
                     pending_work: Default::default(),
-                    pending_diagnostic_updates: 0,
+                    has_pending_diagnostic_updates: false,
+                    progress_tokens: Default::default(),
                 },
             );
             cx.notify();
@@ -6431,7 +6453,7 @@ mod tests {
 
         let mut events = subscribe(&project, cx);
 
-        let mut fake_server = fake_servers.next().await.unwrap();
+        let fake_server = fake_servers.next().await.unwrap();
         fake_server.start_progress(progress_token).await;
         assert_eq!(
             events.next().await.unwrap(),
@@ -6440,10 +6462,6 @@ mod tests {
             }
         );
 
-        fake_server.start_progress(progress_token).await;
-        fake_server.end_progress(progress_token).await;
-        fake_server.start_progress(progress_token).await;
-
         fake_server.notify::<lsp::notification::PublishDiagnostics>(
             lsp::PublishDiagnosticsParams {
                 uri: Url::from_file_path("/dir/a.rs").unwrap(),
@@ -6464,8 +6482,7 @@ mod tests {
             }
         );
 
-        fake_server.end_progress(progress_token).await;
-        fake_server.end_progress(progress_token).await;
+        fake_server.end_progress(progress_token);
         assert_eq!(
             events.next().await.unwrap(),
             Event::DiskBasedDiagnosticsFinished {
@@ -6555,7 +6572,7 @@ mod tests {
             .unwrap();
 
         // Simulate diagnostics starting to update.
-        let mut fake_server = fake_servers.next().await.unwrap();
+        let fake_server = fake_servers.next().await.unwrap();
         fake_server.start_progress(progress_token).await;
 
         // Restart the server before the diagnostics finish updating.
@@ -6565,7 +6582,7 @@ mod tests {
         let mut events = subscribe(&project, cx);
 
         // Simulate the newly started server sending more diagnostics.
-        let mut fake_server = fake_servers.next().await.unwrap();
+        let fake_server = fake_servers.next().await.unwrap();
         fake_server.start_progress(progress_token).await;
         assert_eq!(
             events.next().await.unwrap(),
@@ -6584,7 +6601,7 @@ mod tests {
 
         // All diagnostics are considered done, despite the old server's diagnostic
         // task never completing.
-        fake_server.end_progress(progress_token).await;
+        fake_server.end_progress(progress_token);
         assert_eq!(
             events.next().await.unwrap(),
             Event::DiskBasedDiagnosticsFinished {