remote dev: Allow canceling language server work in editor (#19946)

Thorsten Ball created

Release Notes:

- Added ability to cancel language server work in remote development.

Demo:



https://github.com/user-attachments/assets/c9ca91a5-617f-4886-a458-87c563c5a247

Change summary

crates/editor/src/editor.rs                      |   2 
crates/project/src/lsp_store.rs                  | 194 ++++++++++++-----
crates/proto/proto/zed.proto                     |  24 ++
crates/proto/src/proto.rs                        |   5 
crates/remote_server/src/remote_editing_tests.rs | 166 +++++++++++++++
5 files changed, 325 insertions(+), 66 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -10460,7 +10460,7 @@ impl Editor {
 
     fn cancel_language_server_work(
         &mut self,
-        _: &CancelLanguageServerWork,
+        _: &actions::CancelLanguageServerWork,
         cx: &mut ViewContext<Self>,
     ) {
         if let Some(project) = self.project.clone() {

crates/project/src/lsp_store.rs 🔗

@@ -787,6 +787,7 @@ impl LspStore {
     pub fn init(client: &AnyProtoClient) {
         client.add_model_request_handler(Self::handle_multi_lsp_query);
         client.add_model_request_handler(Self::handle_restart_language_servers);
+        client.add_model_request_handler(Self::handle_cancel_language_server_work);
         client.add_model_message_handler(Self::handle_start_language_server);
         client.add_model_message_handler(Self::handle_update_language_server);
         client.add_model_message_handler(Self::handle_language_server_log);
@@ -4118,7 +4119,7 @@ impl LspStore {
                         LanguageServerProgress {
                             title: payload.title,
                             is_disk_based_diagnostics_progress: false,
-                            is_cancellable: false,
+                            is_cancellable: payload.is_cancellable.unwrap_or(false),
                             message: payload.message,
                             percentage: payload.percentage.map(|p| p as usize),
                             last_update_at: cx.background_executor().now(),
@@ -4134,7 +4135,7 @@ impl LspStore {
                         LanguageServerProgress {
                             title: None,
                             is_disk_based_diagnostics_progress: false,
-                            is_cancellable: false,
+                            is_cancellable: payload.is_cancellable.unwrap_or(false),
                             message: payload.message,
                             percentage: payload.percentage.map(|p| p as usize),
                             last_update_at: cx.background_executor().now(),
@@ -4635,6 +4636,7 @@ impl LspStore {
                                 token,
                                 message: report.message,
                                 percentage: report.percentage,
+                                is_cancellable: report.cancellable,
                             },
                         ),
                     })
@@ -4668,6 +4670,7 @@ impl LspStore {
                 title: progress.title,
                 message: progress.message,
                 percentage: progress.percentage.map(|p| p as u32),
+                is_cancellable: Some(progress.is_cancellable),
             }),
         })
     }
@@ -4698,6 +4701,9 @@ impl LspStore {
                         if progress.percentage.is_some() {
                             entry.percentage = progress.percentage;
                         }
+                        if progress.is_cancellable != entry.is_cancellable {
+                            entry.is_cancellable = progress.is_cancellable;
+                        }
                         cx.notify();
                         return true;
                     }
@@ -5168,22 +5174,52 @@ impl LspStore {
         mut cx: AsyncAppContext,
     ) -> Result<proto::Ack> {
         this.update(&mut cx, |this, cx| {
-            let buffers: Vec<_> = envelope
-                .payload
-                .buffer_ids
-                .into_iter()
-                .flat_map(|buffer_id| {
-                    this.buffer_store
-                        .read(cx)
-                        .get(BufferId::new(buffer_id).log_err()?)
-                })
-                .collect();
-            this.restart_language_servers_for_buffers(buffers, cx)
+            let buffers = this.buffer_ids_to_buffers(envelope.payload.buffer_ids.into_iter(), cx);
+            this.restart_language_servers_for_buffers(buffers, cx);
+        })?;
+
+        Ok(proto::Ack {})
+    }
+
+    pub async fn handle_cancel_language_server_work(
+        this: Model<Self>,
+        envelope: TypedEnvelope<proto::CancelLanguageServerWork>,
+        mut cx: AsyncAppContext,
+    ) -> Result<proto::Ack> {
+        this.update(&mut cx, |this, cx| {
+            if let Some(work) = envelope.payload.work {
+                match work {
+                    proto::cancel_language_server_work::Work::Buffers(buffers) => {
+                        let buffers =
+                            this.buffer_ids_to_buffers(buffers.buffer_ids.into_iter(), cx);
+                        this.cancel_language_server_work_for_buffers(buffers, cx);
+                    }
+                    proto::cancel_language_server_work::Work::LanguageServerWork(work) => {
+                        let server_id = LanguageServerId::from_proto(work.language_server_id);
+                        this.cancel_language_server_work(server_id, work.token, cx);
+                    }
+                }
+            }
         })?;
 
         Ok(proto::Ack {})
     }
 
+    fn buffer_ids_to_buffers(
+        &mut self,
+        buffer_ids: impl Iterator<Item = u64>,
+        cx: &mut ModelContext<Self>,
+    ) -> Vec<Model<Buffer>> {
+        buffer_ids
+            .into_iter()
+            .flat_map(|buffer_id| {
+                self.buffer_store
+                    .read(cx)
+                    .get(BufferId::new(buffer_id).log_err()?)
+            })
+            .collect::<Vec<_>>()
+    }
+
     async fn handle_apply_additional_edits_for_completion(
         this: Model<Self>,
         envelope: TypedEnvelope<proto::ApplyCompletionAdditionalEdits>,
@@ -6728,16 +6764,89 @@ impl LspStore {
         buffers: impl IntoIterator<Item = Model<Buffer>>,
         cx: &mut ModelContext<Self>,
     ) {
-        let servers = buffers
-            .into_iter()
-            .flat_map(|buffer| {
-                self.language_server_ids_for_buffer(buffer.read(cx), cx)
-                    .into_iter()
-            })
-            .collect::<HashSet<_>>();
+        if let Some((client, project_id)) = self.upstream_client() {
+            let request = client.request(proto::CancelLanguageServerWork {
+                project_id,
+                work: Some(proto::cancel_language_server_work::Work::Buffers(
+                    proto::cancel_language_server_work::Buffers {
+                        buffer_ids: buffers
+                            .into_iter()
+                            .map(|b| b.read(cx).remote_id().to_proto())
+                            .collect(),
+                    },
+                )),
+            });
+            cx.background_executor()
+                .spawn(request)
+                .detach_and_log_err(cx);
+        } else {
+            let servers = buffers
+                .into_iter()
+                .flat_map(|buffer| {
+                    self.language_server_ids_for_buffer(buffer.read(cx), cx)
+                        .into_iter()
+                })
+                .collect::<HashSet<_>>();
+
+            for server_id in servers {
+                self.cancel_language_server_work(server_id, None, cx);
+            }
+        }
+    }
+
+    pub(crate) fn cancel_language_server_work(
+        &mut self,
+        server_id: LanguageServerId,
+        token_to_cancel: Option<String>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        if let Some(local) = self.as_local() {
+            let status = self.language_server_statuses.get(&server_id);
+            let server = local.language_servers.get(&server_id);
+            if let Some((LanguageServerState::Running { server, .. }, status)) = server.zip(status)
+            {
+                for (token, progress) in &status.pending_work {
+                    if let Some(token_to_cancel) = token_to_cancel.as_ref() {
+                        if token != token_to_cancel {
+                            continue;
+                        }
+                    }
+                    if progress.is_cancellable {
+                        server
+                            .notify::<lsp::notification::WorkDoneProgressCancel>(
+                                WorkDoneProgressCancelParams {
+                                    token: lsp::NumberOrString::String(token.clone()),
+                                },
+                            )
+                            .ok();
+                    }
 
-        for server_id in servers {
-            self.cancel_language_server_work(server_id, None, cx);
+                    if progress.is_cancellable {
+                        server
+                            .notify::<lsp::notification::WorkDoneProgressCancel>(
+                                WorkDoneProgressCancelParams {
+                                    token: lsp::NumberOrString::String(token.clone()),
+                                },
+                            )
+                            .ok();
+                    }
+                }
+            }
+        } else if let Some((client, project_id)) = self.upstream_client() {
+            let request = client.request(proto::CancelLanguageServerWork {
+                project_id,
+                work: Some(
+                    proto::cancel_language_server_work::Work::LanguageServerWork(
+                        proto::cancel_language_server_work::LanguageServerWork {
+                            language_server_id: server_id.to_proto(),
+                            token: token_to_cancel,
+                        },
+                    ),
+                ),
+            });
+            cx.background_executor()
+                .spawn(request)
+                .detach_and_log_err(cx);
         }
     }
 
@@ -6868,47 +6977,6 @@ impl LspStore {
         }
     }
 
-    pub(crate) fn cancel_language_server_work(
-        &mut self,
-        server_id: LanguageServerId,
-        token_to_cancel: Option<String>,
-        _cx: &mut ModelContext<Self>,
-    ) {
-        let Some(local) = self.as_local() else {
-            return;
-        };
-        let status = self.language_server_statuses.get(&server_id);
-        let server = local.language_servers.get(&server_id);
-        if let Some((LanguageServerState::Running { server, .. }, status)) = server.zip(status) {
-            for (token, progress) in &status.pending_work {
-                if let Some(token_to_cancel) = token_to_cancel.as_ref() {
-                    if token != token_to_cancel {
-                        continue;
-                    }
-                }
-                if progress.is_cancellable {
-                    server
-                        .notify::<lsp::notification::WorkDoneProgressCancel>(
-                            WorkDoneProgressCancelParams {
-                                token: lsp::NumberOrString::String(token.clone()),
-                            },
-                        )
-                        .ok();
-                }
-
-                if progress.is_cancellable {
-                    server
-                        .notify::<lsp::notification::WorkDoneProgressCancel>(
-                            WorkDoneProgressCancelParams {
-                                token: lsp::NumberOrString::String(token.clone()),
-                            },
-                        )
-                        .ok();
-                }
-            }
-        }
-    }
-
     pub fn wait_for_remote_buffer(
         &mut self,
         id: BufferId,

crates/proto/proto/zed.proto 🔗

@@ -292,7 +292,9 @@ message Envelope {
         GetPathMetadataResponse get_path_metadata_response = 279;
 
         GetPanicFiles get_panic_files = 280;
-        GetPanicFilesResponse get_panic_files_response = 281; // current max
+        GetPanicFilesResponse get_panic_files_response = 281;
+
+        CancelLanguageServerWork cancel_language_server_work = 282; // current max
     }
 
     reserved 87 to 88;
@@ -1257,12 +1259,14 @@ message LspWorkStart {
     optional string title = 4;
     optional string message = 2;
     optional uint32 percentage = 3;
+    optional bool is_cancellable = 5;
 }
 
 message LspWorkProgress {
     string token = 1;
     optional string message = 2;
     optional uint32 percentage = 3;
+    optional bool is_cancellable = 4;
 }
 
 message LspWorkEnd {
@@ -2500,3 +2504,21 @@ message GetPanicFiles {
 message GetPanicFilesResponse {
     repeated string file_contents = 2;
 }
+
+message CancelLanguageServerWork {
+    uint64 project_id = 1;
+
+    oneof work {
+        Buffers buffers = 2;
+        LanguageServerWork language_server_work = 3;
+    }
+
+    message Buffers {
+        repeated uint64 buffer_ids = 2;
+    }
+
+    message LanguageServerWork {
+        uint64 language_server_id = 1;
+        optional string token = 2;
+    }
+}

crates/proto/src/proto.rs 🔗

@@ -366,6 +366,7 @@ messages!(
     (GetPathMetadataResponse, Background),
     (GetPanicFiles, Background),
     (GetPanicFilesResponse, Background),
+    (CancelLanguageServerWork, Foreground),
 );
 
 request_messages!(
@@ -486,7 +487,8 @@ request_messages!(
     (ActivateToolchain, Ack),
     (ActiveToolchain, ActiveToolchainResponse),
     (GetPathMetadata, GetPathMetadataResponse),
-    (GetPanicFiles, GetPanicFilesResponse)
+    (GetPanicFiles, GetPanicFilesResponse),
+    (CancelLanguageServerWork, Ack),
 );
 
 entity_messages!(
@@ -570,6 +572,7 @@ entity_messages!(
     ActivateToolchain,
     ActiveToolchain,
     GetPathMetadata,
+    CancelLanguageServerWork,
 );
 
 entity_messages!(

crates/remote_server/src/remote_editing_tests.rs 🔗

@@ -528,6 +528,172 @@ async fn test_remote_lsp(cx: &mut TestAppContext, server_cx: &mut TestAppContext
     })
 }
 
+#[gpui::test]
+async fn test_remote_cancel_language_server_work(
+    cx: &mut TestAppContext,
+    server_cx: &mut TestAppContext,
+) {
+    let fs = FakeFs::new(server_cx.executor());
+    fs.insert_tree(
+        "/code",
+        json!({
+            "project1": {
+                ".git": {},
+                "README.md": "# project 1",
+                "src": {
+                    "lib.rs": "fn one() -> usize { 1 }"
+                }
+            },
+        }),
+    )
+    .await;
+
+    let (project, headless) = init_test(&fs, cx, server_cx).await;
+
+    fs.insert_tree(
+        "/code/project1/.zed",
+        json!({
+            "settings.json": r#"
+          {
+            "languages": {"Rust":{"language_servers":["rust-analyzer"]}},
+            "lsp": {
+              "rust-analyzer": {
+                "binary": {
+                  "path": "~/.cargo/bin/rust-analyzer"
+                }
+              }
+            }
+          }"#
+        }),
+    )
+    .await;
+
+    cx.update_model(&project, |project, _| {
+        project.languages().register_test_language(LanguageConfig {
+            name: "Rust".into(),
+            matcher: LanguageMatcher {
+                path_suffixes: vec!["rs".into()],
+                ..Default::default()
+            },
+            ..Default::default()
+        });
+        project.languages().register_fake_lsp_adapter(
+            "Rust",
+            FakeLspAdapter {
+                name: "rust-analyzer",
+                ..Default::default()
+            },
+        )
+    });
+
+    let mut fake_lsp = server_cx.update(|cx| {
+        headless.read(cx).languages.register_fake_language_server(
+            LanguageServerName("rust-analyzer".into()),
+            Default::default(),
+            None,
+        )
+    });
+
+    cx.run_until_parked();
+
+    let worktree_id = project
+        .update(cx, |project, cx| {
+            project.find_or_create_worktree("/code/project1", true, cx)
+        })
+        .await
+        .unwrap()
+        .0
+        .read_with(cx, |worktree, _| worktree.id());
+
+    cx.run_until_parked();
+
+    let buffer = project
+        .update(cx, |project, cx| {
+            project.open_buffer((worktree_id, Path::new("src/lib.rs")), cx)
+        })
+        .await
+        .unwrap();
+
+    cx.run_until_parked();
+
+    let mut fake_lsp = fake_lsp.next().await.unwrap();
+
+    // Cancelling all language server work for a given buffer
+    {
+        // Two operations, one cancellable and one not.
+        fake_lsp
+            .start_progress_with(
+                "another-token",
+                lsp::WorkDoneProgressBegin {
+                    cancellable: Some(false),
+                    ..Default::default()
+                },
+            )
+            .await;
+
+        let progress_token = "the-progress-token";
+        fake_lsp
+            .start_progress_with(
+                progress_token,
+                lsp::WorkDoneProgressBegin {
+                    cancellable: Some(true),
+                    ..Default::default()
+                },
+            )
+            .await;
+
+        cx.executor().run_until_parked();
+
+        project.update(cx, |project, cx| {
+            project.cancel_language_server_work_for_buffers([buffer.clone()], cx)
+        });
+
+        cx.executor().run_until_parked();
+
+        // Verify the cancellation was received on the server side
+        let cancel_notification = fake_lsp
+            .receive_notification::<lsp::notification::WorkDoneProgressCancel>()
+            .await;
+        assert_eq!(
+            cancel_notification.token,
+            lsp::NumberOrString::String(progress_token.into())
+        );
+    }
+
+    // Cancelling work by server_id and token
+    {
+        let server_id = fake_lsp.server.server_id();
+        let progress_token = "the-progress-token";
+
+        fake_lsp
+            .start_progress_with(
+                progress_token,
+                lsp::WorkDoneProgressBegin {
+                    cancellable: Some(true),
+                    ..Default::default()
+                },
+            )
+            .await;
+
+        cx.executor().run_until_parked();
+
+        project.update(cx, |project, cx| {
+            project.cancel_language_server_work(server_id, Some(progress_token.into()), cx)
+        });
+
+        cx.executor().run_until_parked();
+
+        // Verify the cancellation was received on the server side
+        let cancel_notification = fake_lsp
+            .receive_notification::<lsp::notification::WorkDoneProgressCancel>()
+            .await;
+        assert_eq!(
+            cancel_notification.token,
+            lsp::NumberOrString::String(progress_token.into())
+        );
+    }
+}
+
 #[gpui::test]
 async fn test_remote_reload(cx: &mut TestAppContext, server_cx: &mut TestAppContext) {
     let fs = FakeFs::new(server_cx.executor());