Use selection instead of just the cursor when fetching code actions

Antonio Scandurra created

Change summary

crates/editor/src/editor.rs   |  19 ++++--
crates/language/src/buffer.rs |   2 
crates/language/src/proto.rs  |  15 +++-
crates/project/src/project.rs | 100 +++++++++++++++++++-----------------
crates/rpc/proto/zed.proto    |   8 +-
crates/text/src/anchor.rs     |   7 ++
6 files changed, 88 insertions(+), 63 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -2230,12 +2230,17 @@ impl Editor {
 
     fn refresh_code_actions(&mut self, cx: &mut ViewContext<Self>) -> Option<()> {
         let project = self.project.as_ref()?;
-        let new_cursor_position = self.newest_anchor_selection().head();
-        let (buffer, head) = self
-            .buffer
-            .read(cx)
-            .text_anchor_for_position(new_cursor_position, cx)?;
-        let actions = project.update(cx, |project, cx| project.code_actions(&buffer, head, cx));
+        let buffer = self.buffer.read(cx);
+        let newest_selection = self.newest_anchor_selection().clone();
+        let (start_buffer, start) = buffer.text_anchor_for_position(newest_selection.start, cx)?;
+        let (end_buffer, end) = buffer.text_anchor_for_position(newest_selection.end, cx)?;
+        if start_buffer != end_buffer {
+            return None;
+        }
+
+        let actions = project.update(cx, |project, cx| {
+            project.code_actions(&start_buffer, start..end, cx)
+        });
         self.code_actions_task = Some(cx.spawn_weak(|this, mut cx| async move {
             let actions = actions.await;
             if let Some(this) = this.upgrade(&cx) {
@@ -2244,7 +2249,7 @@ impl Editor {
                         if actions.is_empty() {
                             None
                         } else {
-                            Some((buffer, actions.into()))
+                            Some((start_buffer, actions.into()))
                         }
                     });
                     cx.notify();

crates/language/src/buffer.rs 🔗

@@ -123,7 +123,7 @@ pub struct Completion {
 
 #[derive(Clone, Debug)]
 pub struct CodeAction {
-    pub position: Anchor,
+    pub range: Range<Anchor>,
     pub lsp_action: lsp::CodeAction,
 }
 

crates/language/src/proto.rs 🔗

@@ -428,19 +428,24 @@ pub fn deserialize_completion(
 
 pub fn serialize_code_action(action: &CodeAction) -> proto::CodeAction {
     proto::CodeAction {
-        position: Some(serialize_anchor(&action.position)),
+        start: Some(serialize_anchor(&action.range.start)),
+        end: Some(serialize_anchor(&action.range.end)),
         lsp_action: serde_json::to_vec(&action.lsp_action).unwrap(),
     }
 }
 
 pub fn deserialize_code_action(action: proto::CodeAction) -> Result<CodeAction> {
-    let position = action
-        .position
+    let start = action
+        .start
         .and_then(deserialize_anchor)
-        .ok_or_else(|| anyhow!("invalid position"))?;
+        .ok_or_else(|| anyhow!("invalid start"))?;
+    let end = action
+        .end
+        .and_then(deserialize_anchor)
+        .ok_or_else(|| anyhow!("invalid end"))?;
     let lsp_action = serde_json::from_slice(&action.lsp_action)?;
     Ok(CodeAction {
-        position,
+        range: start..end,
         lsp_action,
     })
 }

crates/project/src/project.rs 🔗

@@ -15,9 +15,9 @@ use gpui::{
 use language::{
     point_from_lsp,
     proto::{deserialize_anchor, serialize_anchor},
-    range_from_lsp, Bias, Buffer, CodeAction, Completion, CompletionLabel, Diagnostic,
-    DiagnosticEntry, File as _, Language, LanguageRegistry, PointUtf16, ToLspPosition,
-    ToPointUtf16, Transaction,
+    range_from_lsp, AnchorRangeExt, Bias, Buffer, CodeAction, Completion, CompletionLabel,
+    Diagnostic, DiagnosticEntry, File as _, Language, LanguageRegistry, PointUtf16, ToLspPosition,
+    ToOffset, ToPointUtf16, Transaction,
 };
 use lsp::{DiagnosticSeverity, LanguageServer};
 use postage::{prelude::Stream, watch};
@@ -1474,32 +1474,30 @@ impl Project {
         }
     }
 
-    pub fn code_actions<T: ToPointUtf16>(
+    pub fn code_actions<T: ToOffset>(
         &self,
-        source_buffer_handle: &ModelHandle<Buffer>,
-        position: T,
+        buffer_handle: &ModelHandle<Buffer>,
+        range: Range<T>,
         cx: &mut ModelContext<Self>,
     ) -> Task<Result<Vec<CodeAction>>> {
-        let source_buffer_handle = source_buffer_handle.clone();
-        let source_buffer = source_buffer_handle.read(cx);
-        let buffer_id = source_buffer.remote_id();
+        let buffer_handle = buffer_handle.clone();
+        let buffer = buffer_handle.read(cx);
+        let buffer_id = buffer.remote_id();
         let worktree;
         let buffer_abs_path;
-        if let Some(file) = File::from_dyn(source_buffer.file()) {
+        if let Some(file) = File::from_dyn(buffer.file()) {
             worktree = file.worktree.clone();
             buffer_abs_path = file.as_local().map(|f| f.abs_path(cx));
         } else {
             return Task::ready(Ok(Default::default()));
         };
-
-        let position = position.to_point_utf16(source_buffer);
-        let anchor = source_buffer.anchor_after(position);
+        let range = buffer.anchor_before(range.start)..buffer.anchor_before(range.end);
 
         if worktree.read(cx).as_local().is_some() {
             let buffer_abs_path = buffer_abs_path.unwrap();
             let lang_name;
             let lang_server;
-            if let Some(lang) = source_buffer.language() {
+            if let Some(lang) = buffer.language() {
                 lang_name = lang.name().to_string();
                 if let Some(server) = self
                     .language_servers
@@ -1513,42 +1511,42 @@ impl Project {
                 return Task::ready(Ok(Default::default()));
             }
 
+            let actions =
+                lang_server.request::<lsp::request::CodeActionRequest>(lsp::CodeActionParams {
+                    text_document: lsp::TextDocumentIdentifier::new(
+                        lsp::Url::from_file_path(buffer_abs_path).unwrap(),
+                    ),
+                    range: lsp::Range::new(
+                        range.start.to_point_utf16(buffer).to_lsp_position(),
+                        range.end.to_point_utf16(buffer).to_lsp_position(),
+                    ),
+                    work_done_progress_params: Default::default(),
+                    partial_result_params: Default::default(),
+                    context: lsp::CodeActionContext {
+                        diagnostics: Default::default(),
+                        only: Some(vec![
+                            lsp::CodeActionKind::QUICKFIX,
+                            lsp::CodeActionKind::REFACTOR,
+                            lsp::CodeActionKind::REFACTOR_EXTRACT,
+                        ]),
+                    },
+                });
             cx.foreground().spawn(async move {
-                let actions = lang_server
-                    .request::<lsp::request::CodeActionRequest>(lsp::CodeActionParams {
-                        text_document: lsp::TextDocumentIdentifier::new(
-                            lsp::Url::from_file_path(buffer_abs_path).unwrap(),
-                        ),
-                        range: lsp::Range::new(
-                            position.to_lsp_position(),
-                            position.to_lsp_position(),
-                        ),
-                        work_done_progress_params: Default::default(),
-                        partial_result_params: Default::default(),
-                        context: lsp::CodeActionContext {
-                            diagnostics: Default::default(),
-                            only: Some(vec![
-                                lsp::CodeActionKind::QUICKFIX,
-                                lsp::CodeActionKind::REFACTOR,
-                                lsp::CodeActionKind::REFACTOR_EXTRACT,
-                            ]),
-                        },
-                    })
+                Ok(actions
                     .await?
                     .unwrap_or_default()
                     .into_iter()
                     .filter_map(|entry| {
                         if let lsp::CodeActionOrCommand::CodeAction(lsp_action) = entry {
                             Some(CodeAction {
-                                position: anchor.clone(),
+                                range: range.clone(),
                                 lsp_action,
                             })
                         } else {
                             None
                         }
                     })
-                    .collect();
-                Ok(actions)
+                    .collect())
             })
         } else if let Some(project_id) = self.remote_id() {
             let rpc = self.client.clone();
@@ -1557,7 +1555,8 @@ impl Project {
                     .request(proto::GetCodeActions {
                         project_id,
                         buffer_id,
-                        position: Some(language::proto::serialize_anchor(&anchor)),
+                        start: Some(language::proto::serialize_anchor(&range.start)),
+                        end: Some(language::proto::serialize_anchor(&range.end)),
                     })
                     .await?;
                 response
@@ -1590,25 +1589,29 @@ impl Project {
             } else {
                 return Task::ready(Err(anyhow!("buffer does not have a language server")));
             };
-            let position = action.position.to_point_utf16(buffer).to_lsp_position();
+            let range = action.range.to_point_utf16(buffer);
             let fs = self.fs.clone();
 
             cx.spawn(|this, mut cx| async move {
-                if let Some(range) = action
+                if let Some(lsp_range) = action
                     .lsp_action
                     .data
                     .as_mut()
                     .and_then(|d| d.get_mut("codeActionParams"))
                     .and_then(|d| d.get_mut("range"))
                 {
-                    *range = serde_json::to_value(&lsp::Range::new(position, position)).unwrap();
+                    *lsp_range = serde_json::to_value(&lsp::Range::new(
+                        range.start.to_lsp_position(),
+                        range.end.to_lsp_position(),
+                    ))
+                    .unwrap();
                     action.lsp_action = lang_server
                         .request::<lsp::request::CodeActionResolveRequest>(action.lsp_action)
                         .await?;
                 } else {
                     let actions = this
                         .update(&mut cx, |this, cx| {
-                            this.code_actions(&buffer_handle, action.position.clone(), cx)
+                            this.code_actions(&buffer_handle, action.range, cx)
                         })
                         .await?;
                     action.lsp_action = actions
@@ -2357,18 +2360,23 @@ impl Project {
         mut cx: AsyncAppContext,
     ) -> Result<proto::GetCodeActionsResponse> {
         let sender_id = envelope.original_sender_id()?;
-        let position = envelope
+        let start = envelope
             .payload
-            .position
+            .start
             .and_then(language::proto::deserialize_anchor)
-            .ok_or_else(|| anyhow!("invalid position"))?;
+            .ok_or_else(|| anyhow!("invalid start"))?;
+        let end = envelope
+            .payload
+            .end
+            .and_then(language::proto::deserialize_anchor)
+            .ok_or_else(|| anyhow!("invalid end"))?;
         let code_actions = this.update(&mut cx, |this, cx| {
             let buffer = this
                 .shared_buffers
                 .get(&sender_id)
                 .and_then(|shared_buffers| shared_buffers.get(&envelope.payload.buffer_id).cloned())
                 .ok_or_else(|| anyhow!("unknown buffer id {}", envelope.payload.buffer_id))?;
-            Ok::<_, anyhow::Error>(this.code_actions(&buffer, position, cx))
+            Ok::<_, anyhow::Error>(this.code_actions(&buffer, start..end, cx))
         })?;
 
         Ok(proto::GetCodeActionsResponse {

crates/rpc/proto/zed.proto 🔗

@@ -246,7 +246,8 @@ message Completion {
 message GetCodeActions {
     uint64 project_id = 1;
     uint64 buffer_id = 2;
-    Anchor position = 3;
+    Anchor start = 3;
+    Anchor end = 4;
 }
 
 message GetCodeActionsResponse {
@@ -264,8 +265,9 @@ message ApplyCodeActionResponse {
 }
 
 message CodeAction {
-    Anchor position = 1;
-    bytes lsp_action = 2;
+    Anchor start = 1;
+    Anchor end = 2;
+    bytes lsp_action = 3;
 }
 
 message ProjectTransaction {

crates/text/src/anchor.rs 🔗

@@ -1,5 +1,5 @@
 use super::{Point, ToOffset};
-use crate::{rope::TextDimension, BufferSnapshot};
+use crate::{rope::TextDimension, BufferSnapshot, PointUtf16, ToPointUtf16};
 use anyhow::Result;
 use std::{cmp::Ordering, fmt::Debug, ops::Range};
 use sum_tree::Bias;
@@ -78,6 +78,7 @@ pub trait AnchorRangeExt {
     fn cmp(&self, b: &Range<Anchor>, buffer: &BufferSnapshot) -> Result<Ordering>;
     fn to_offset(&self, content: &BufferSnapshot) -> Range<usize>;
     fn to_point(&self, content: &BufferSnapshot) -> Range<Point>;
+    fn to_point_utf16(&self, content: &BufferSnapshot) -> Range<PointUtf16>;
 }
 
 impl AnchorRangeExt for Range<Anchor> {
@@ -95,4 +96,8 @@ impl AnchorRangeExt for Range<Anchor> {
     fn to_point(&self, content: &BufferSnapshot) -> Range<Point> {
         self.start.summary::<Point>(&content)..self.end.summary::<Point>(&content)
     }
+
+    fn to_point_utf16(&self, content: &BufferSnapshot) -> Range<PointUtf16> {
+        self.start.to_point_utf16(content)..self.end.to_point_utf16(content)
+    }
 }