Use `LspCommand` to handle code actions

Antonio Scandurra created

Change summary

crates/project/src/lsp_command.rs | 206 +++++++++++++++++++++++++++++++-
crates/project/src/project.rs     | 146 ----------------------
2 files changed, 199 insertions(+), 153 deletions(-)

Detailed changes

crates/project/src/lsp_command.rs 🔗

@@ -9,8 +9,8 @@ use gpui::{AppContext, AsyncAppContext, ModelHandle};
 use language::{
     point_from_lsp, point_to_lsp,
     proto::{deserialize_anchor, deserialize_version, serialize_anchor, serialize_version},
-    range_from_lsp, Anchor, Bias, Buffer, CachedLspAdapter, CharKind, Completion, PointUtf16,
-    ToOffset, ToPointUtf16, Unclipped,
+    range_from_lsp, range_to_lsp, Anchor, Bias, Buffer, CachedLspAdapter, CharKind, CodeAction,
+    Completion, OffsetRangeExt, PointUtf16, ToOffset, ToPointUtf16, Unclipped,
 };
 use lsp::{DocumentHighlightKind, LanguageServer, ServerCapabilities};
 use pulldown_cmark::{CodeBlockKind, Event, Options, Parser, Tag};
@@ -29,6 +29,8 @@ pub(crate) trait LspCommand: 'static + Sized {
     fn to_lsp(
         &self,
         path: &Path,
+        buffer: &Buffer,
+        language_server: &Arc<LanguageServer>,
         cx: &AppContext,
     ) -> <Self::LspRequest as lsp::request::Request>::Params;
     async fn response_from_lsp(
@@ -97,6 +99,10 @@ pub(crate) struct GetCompletions {
     pub position: PointUtf16,
 }
 
+pub(crate) struct GetCodeActions {
+    pub range: Range<Anchor>,
+}
+
 #[async_trait(?Send)]
 impl LspCommand for PrepareRename {
     type Response = Option<Range<Anchor>>;
@@ -111,7 +117,13 @@ impl LspCommand for PrepareRename {
         }
     }
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::TextDocumentPositionParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::TextDocumentPositionParams {
         lsp::TextDocumentPositionParams {
             text_document: lsp::TextDocumentIdentifier {
                 uri: lsp::Url::from_file_path(path).unwrap(),
@@ -227,7 +239,13 @@ impl LspCommand for PerformRename {
     type LspRequest = lsp::request::Rename;
     type ProtoRequest = proto::PerformRename;
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::RenameParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::RenameParams {
         lsp::RenameParams {
             text_document_position: lsp::TextDocumentPositionParams {
                 text_document: lsp::TextDocumentIdentifier {
@@ -338,7 +356,13 @@ impl LspCommand for GetDefinition {
     type LspRequest = lsp::request::GotoDefinition;
     type ProtoRequest = proto::GetDefinition;
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::GotoDefinitionParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::GotoDefinitionParams {
         lsp::GotoDefinitionParams {
             text_document_position_params: lsp::TextDocumentPositionParams {
                 text_document: lsp::TextDocumentIdentifier {
@@ -424,7 +448,13 @@ impl LspCommand for GetTypeDefinition {
     type LspRequest = lsp::request::GotoTypeDefinition;
     type ProtoRequest = proto::GetTypeDefinition;
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::GotoTypeDefinitionParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::GotoTypeDefinitionParams {
         lsp::GotoTypeDefinitionParams {
             text_document_position_params: lsp::TextDocumentPositionParams {
                 text_document: lsp::TextDocumentIdentifier {
@@ -699,7 +729,13 @@ impl LspCommand for GetReferences {
     type LspRequest = lsp::request::References;
     type ProtoRequest = proto::GetReferences;
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::ReferenceParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::ReferenceParams {
         lsp::ReferenceParams {
             text_document_position: lsp::TextDocumentPositionParams {
                 text_document: lsp::TextDocumentIdentifier {
@@ -857,7 +893,13 @@ impl LspCommand for GetDocumentHighlights {
         capabilities.document_highlight_provider.is_some()
     }
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::DocumentHighlightParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::DocumentHighlightParams {
         lsp::DocumentHighlightParams {
             text_document_position_params: lsp::TextDocumentPositionParams {
                 text_document: lsp::TextDocumentIdentifier {
@@ -997,7 +1039,13 @@ impl LspCommand for GetHover {
     type LspRequest = lsp::request::HoverRequest;
     type ProtoRequest = proto::GetHover;
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::HoverParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::HoverParams {
         lsp::HoverParams {
             text_document_position_params: lsp::TextDocumentPositionParams {
                 text_document: lsp::TextDocumentIdentifier {
@@ -1212,7 +1260,13 @@ impl LspCommand for GetCompletions {
     type LspRequest = lsp::request::Completion;
     type ProtoRequest = proto::GetCompletions;
 
-    fn to_lsp(&self, path: &Path, _: &AppContext) -> lsp::CompletionParams {
+    fn to_lsp(
+        &self,
+        path: &Path,
+        _: &Buffer,
+        _: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::CompletionParams {
         lsp::CompletionParams {
             text_document_position: lsp::TextDocumentPositionParams::new(
                 lsp::TextDocumentIdentifier::new(lsp::Url::from_file_path(path).unwrap()),
@@ -1406,3 +1460,135 @@ impl LspCommand for GetCompletions {
         message.buffer_id
     }
 }
+
+#[async_trait(?Send)]
+impl LspCommand for GetCodeActions {
+    type Response = Vec<CodeAction>;
+    type LspRequest = lsp::request::CodeActionRequest;
+    type ProtoRequest = proto::GetCodeActions;
+
+    fn check_capabilities(&self, capabilities: &ServerCapabilities) -> bool {
+        capabilities.code_action_provider.is_some()
+    }
+
+    fn to_lsp(
+        &self,
+        path: &Path,
+        buffer: &Buffer,
+        language_server: &Arc<LanguageServer>,
+        _: &AppContext,
+    ) -> lsp::CodeActionParams {
+        let relevant_diagnostics = buffer
+            .snapshot()
+            .diagnostics_in_range::<_, usize>(self.range.clone(), false)
+            .map(|entry| entry.to_lsp_diagnostic_stub())
+            .collect();
+        lsp::CodeActionParams {
+            text_document: lsp::TextDocumentIdentifier::new(
+                lsp::Url::from_file_path(path).unwrap(),
+            ),
+            range: range_to_lsp(self.range.to_point_utf16(buffer)),
+            work_done_progress_params: Default::default(),
+            partial_result_params: Default::default(),
+            context: lsp::CodeActionContext {
+                diagnostics: relevant_diagnostics,
+                only: language_server.code_action_kinds(),
+            },
+        }
+    }
+
+    async fn response_from_lsp(
+        self,
+        actions: Option<lsp::CodeActionResponse>,
+        _: ModelHandle<Project>,
+        _: ModelHandle<Buffer>,
+        _: AsyncAppContext,
+    ) -> Result<Vec<CodeAction>> {
+        Ok(actions
+            .unwrap_or_default()
+            .into_iter()
+            .filter_map(|entry| {
+                if let lsp::CodeActionOrCommand::CodeAction(lsp_action) = entry {
+                    Some(CodeAction {
+                        range: self.range.clone(),
+                        lsp_action,
+                    })
+                } else {
+                    None
+                }
+            })
+            .collect())
+    }
+
+    fn to_proto(&self, project_id: u64, buffer: &Buffer) -> proto::GetCodeActions {
+        proto::GetCodeActions {
+            project_id,
+            buffer_id: buffer.remote_id(),
+            start: Some(language::proto::serialize_anchor(&self.range.start)),
+            end: Some(language::proto::serialize_anchor(&self.range.end)),
+            version: serialize_version(&buffer.version()),
+        }
+    }
+
+    async fn from_proto(
+        message: proto::GetCodeActions,
+        _: ModelHandle<Project>,
+        buffer: ModelHandle<Buffer>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Self> {
+        let start = message
+            .start
+            .and_then(language::proto::deserialize_anchor)
+            .ok_or_else(|| anyhow!("invalid start"))?;
+        let end = message
+            .end
+            .and_then(language::proto::deserialize_anchor)
+            .ok_or_else(|| anyhow!("invalid end"))?;
+        buffer
+            .update(&mut cx, |buffer, _| {
+                buffer.wait_for_version(deserialize_version(&message.version))
+            })
+            .await?;
+
+        Ok(Self { range: start..end })
+    }
+
+    fn response_to_proto(
+        code_actions: Vec<CodeAction>,
+        _: &mut Project,
+        _: PeerId,
+        buffer_version: &clock::Global,
+        _: &mut AppContext,
+    ) -> proto::GetCodeActionsResponse {
+        proto::GetCodeActionsResponse {
+            actions: code_actions
+                .iter()
+                .map(language::proto::serialize_code_action)
+                .collect(),
+            version: serialize_version(&buffer_version),
+        }
+    }
+
+    async fn response_from_proto(
+        self,
+        message: proto::GetCodeActionsResponse,
+        _: ModelHandle<Project>,
+        buffer: ModelHandle<Buffer>,
+        mut cx: AsyncAppContext,
+    ) -> Result<Vec<CodeAction>> {
+        buffer
+            .update(&mut cx, |buffer, _| {
+                buffer.wait_for_version(deserialize_version(&message.version))
+            })
+            .await?;
+        message
+            .actions
+            .into_iter()
+            .map(language::proto::deserialize_code_action)
+            .collect()
+    }
+
+    fn buffer_id_from_proto(message: &proto::GetCodeActions) -> u64 {
+        message.buffer_id
+    }
+}

crates/project/src/project.rs 🔗

@@ -409,7 +409,7 @@ impl Project {
         client.add_model_request_handler(Self::handle_reload_buffers);
         client.add_model_request_handler(Self::handle_synchronize_buffers);
         client.add_model_request_handler(Self::handle_format_buffers);
-        client.add_model_request_handler(Self::handle_get_code_actions);
+        client.add_model_request_handler(Self::handle_lsp_command::<GetCodeActions>);
         client.add_model_request_handler(Self::handle_lsp_command::<GetCompletions>);
         client.add_model_request_handler(Self::handle_lsp_command::<GetHover>);
         client.add_model_request_handler(Self::handle_lsp_command::<GetDefinition>);
@@ -3704,106 +3704,9 @@ impl Project {
         range: Range<T>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<CodeAction>>> {
-        let buffer_handle = buffer_handle.clone();
         let buffer = buffer_handle.read(cx);
-        let snapshot = buffer.snapshot();
-        let relevant_diagnostics = snapshot
-            .diagnostics_in_range::<usize, usize>(range.to_offset(&snapshot), false)
-            .map(|entry| entry.to_lsp_diagnostic_stub())
-            .collect();
-        let buffer_id = buffer.remote_id();
-        let worktree;
-        let buffer_abs_path;
-        if let Some(file) = File::from_dyn(buffer.file()) {
-            worktree = file.worktree.clone();
-            buffer_abs_path = file.as_local().map(|f| f.abs_path(cx));
-        } else {
-            return Task::ready(Ok(Vec::new()));
-        };
         let range = buffer.anchor_before(range.start)..buffer.anchor_before(range.end);
-
-        if worktree.read(cx).as_local().is_some() {
-            let buffer_abs_path = buffer_abs_path.unwrap();
-            let lang_server = if let Some((_, server)) = self.language_server_for_buffer(buffer, cx)
-            {
-                server.clone()
-            } else {
-                return Task::ready(Ok(Vec::new()));
-            };
-
-            let lsp_range = range_to_lsp(range.to_point_utf16(buffer));
-            cx.foreground().spawn(async move {
-                if lang_server.capabilities().code_action_provider.is_none() {
-                    return Ok(Vec::new());
-                }
-
-                Ok(lang_server
-                    .request::<lsp::request::CodeActionRequest>(lsp::CodeActionParams {
-                        text_document: lsp::TextDocumentIdentifier::new(
-                            lsp::Url::from_file_path(buffer_abs_path).unwrap(),
-                        ),
-                        range: lsp_range,
-                        work_done_progress_params: Default::default(),
-                        partial_result_params: Default::default(),
-                        context: lsp::CodeActionContext {
-                            diagnostics: relevant_diagnostics,
-                            only: lang_server.code_action_kinds(),
-                        },
-                    })
-                    .await?
-                    .unwrap_or_default()
-                    .into_iter()
-                    .filter_map(|entry| {
-                        if let lsp::CodeActionOrCommand::CodeAction(lsp_action) = entry {
-                            Some(CodeAction {
-                                range: range.clone(),
-                                lsp_action,
-                            })
-                        } else {
-                            None
-                        }
-                    })
-                    .collect())
-            })
-        } else if let Some(project_id) = self.remote_id() {
-            let rpc = self.client.clone();
-            let version = buffer.version();
-            cx.spawn_weak(|this, mut cx| async move {
-                let response = rpc
-                    .request(proto::GetCodeActions {
-                        project_id,
-                        buffer_id,
-                        start: Some(language::proto::serialize_anchor(&range.start)),
-                        end: Some(language::proto::serialize_anchor(&range.end)),
-                        version: serialize_version(&version),
-                    })
-                    .await?;
-
-                if this
-                    .upgrade(&cx)
-                    .ok_or_else(|| anyhow!("project was dropped"))?
-                    .read_with(&cx, |this, _| this.is_read_only())
-                {
-                    return Err(anyhow!(
-                        "failed to get code actions: project was disconnected"
-                    ));
-                } else {
-                    buffer_handle
-                        .update(&mut cx, |buffer, _| {
-                            buffer.wait_for_version(deserialize_version(&response.version))
-                        })
-                        .await?;
-
-                    response
-                        .actions
-                        .into_iter()
-                        .map(language::proto::deserialize_code_action)
-                        .collect()
-                }
-            })
-        } else {
-            Task::ready(Ok(Default::default()))
-        }
+        self.request_lsp(buffer_handle.clone(), GetCodeActions { range }, cx)
     }
 
     pub fn apply_code_action(
@@ -4288,7 +4191,7 @@ impl Project {
                 self.language_server_for_buffer(buffer, cx)
                     .map(|(_, server)| server.clone()),
             ) {
-                let lsp_params = request.to_lsp(&file.abs_path(cx), cx);
+                let lsp_params = request.to_lsp(&file.abs_path(cx), buffer, &language_server, cx);
                 return cx.spawn(|this, cx| async move {
                     if !request.check_capabilities(language_server.capabilities()) {
                         return Ok(Default::default());
@@ -5493,49 +5396,6 @@ impl Project {
         })
     }
 
-    async fn handle_get_code_actions(
-        this: ModelHandle<Self>,
-        envelope: TypedEnvelope<proto::GetCodeActions>,
-        _: Arc<Client>,
-        mut cx: AsyncAppContext,
-    ) -> Result<proto::GetCodeActionsResponse> {
-        let start = envelope
-            .payload
-            .start
-            .and_then(language::proto::deserialize_anchor)
-            .ok_or_else(|| anyhow!("invalid start"))?;
-        let end = envelope
-            .payload
-            .end
-            .and_then(language::proto::deserialize_anchor)
-            .ok_or_else(|| anyhow!("invalid end"))?;
-        let buffer = this.update(&mut cx, |this, cx| {
-            this.opened_buffers
-                .get(&envelope.payload.buffer_id)
-                .and_then(|buffer| buffer.upgrade(cx))
-                .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))
-        })?;
-        buffer
-            .update(&mut cx, |buffer, _| {
-                buffer.wait_for_version(deserialize_version(&envelope.payload.version))
-            })
-            .await?;
-
-        let version = buffer.read_with(&cx, |buffer, _| buffer.version());
-        let code_actions = this.update(&mut cx, |this, cx| {
-            Ok::<_, anyhow::Error>(this.code_actions(&buffer, start..end, cx))
-        })?;
-
-        Ok(proto::GetCodeActionsResponse {
-            actions: code_actions
-                .await?
-                .iter()
-                .map(language::proto::serialize_code_action)
-                .collect(),
-            version: serialize_version(&version),
-        })
-    }
-
     async fn handle_apply_code_action(
         this: ModelHandle<Self>,
         envelope: TypedEnvelope<proto::ApplyCodeAction>,