Apply additional edits when confirming a completion

Antonio Scandurra and Nathan Sobo created

Co-Authored-By: Nathan Sobo <nathan@zed.dev>

Change summary

crates/editor/src/editor.rs       |  53 ++++++++++++---
crates/editor/src/multi_buffer.rs |  30 +++++++++
crates/language/src/buffer.rs     | 105 +++++++++++++++++++++++---------
3 files changed, 145 insertions(+), 43 deletions(-)

Detailed changes

crates/editor/src/editor.rs 🔗

@@ -8,6 +8,7 @@ mod multi_buffer;
 mod test;
 
 use aho_corasick::AhoCorasick;
+use anyhow::Result;
 use clock::ReplicaId;
 use collections::{BTreeMap, HashMap, HashSet};
 pub use display_map::DisplayPoint;
@@ -295,7 +296,9 @@ pub fn init(cx: &mut MutableAppContext, path_openers: &mut Vec<Box<dyn PathOpene
     cx.add_action(Editor::unfold);
     cx.add_action(Editor::fold_selected_ranges);
     cx.add_action(Editor::show_completions);
-    cx.add_action(Editor::confirm_completion);
+    cx.add_action(|editor: &mut Editor, _: &ConfirmCompletion, cx| {
+        editor.confirm_completion(cx).detach_and_log_err(cx);
+    });
 }
 
 trait SelectionExt {
@@ -1645,21 +1648,20 @@ impl Editor {
         self.completion_state.take()
     }
 
-    fn confirm_completion(&mut self, _: &ConfirmCompletion, cx: &mut ViewContext<Self>) {
+    fn confirm_completion(&mut self, cx: &mut ViewContext<Self>) -> Task<Result<()>> {
         if let Some(completion_state) = self.hide_completions(cx) {
             if let Some(completion) = completion_state
-                .completions
+                .matches
                 .get(completion_state.selected_item)
+                .and_then(|mat| completion_state.completions.get(mat.candidate_id))
             {
-                self.buffer.update(cx, |buffer, cx| {
-                    buffer.edit_with_autoindent(
-                        [completion.old_range.clone()],
-                        completion.new_text.clone(),
-                        cx,
-                    );
-                })
+                return self.buffer.update(cx, |buffer, cx| {
+                    buffer.apply_completion(completion.clone(), cx)
+                });
             }
         }
+
+        Task::ready(Ok(()))
     }
 
     pub fn has_completions(&self) -> bool {
@@ -6654,9 +6656,9 @@ mod tests {
 
         editor.next_notification(&cx).await;
 
-        editor.update(&mut cx, |editor, cx| {
+        let apply_additional_edits = editor.update(&mut cx, |editor, cx| {
             editor.move_down(&MoveDown, cx);
-            editor.confirm_completion(&ConfirmCompletion, cx);
+            let apply_additional_edits = editor.confirm_completion(cx);
             assert_eq!(
                 editor.text(cx),
                 "
@@ -6666,7 +6668,34 @@ mod tests {
                 "
                 .unindent()
             );
+            apply_additional_edits
         });
+        let (id, _) = fake
+            .receive_request::<lsp::request::ResolveCompletionItem>()
+            .await;
+        fake.respond(
+            id,
+            lsp::CompletionItem {
+                additional_text_edits: Some(vec![lsp::TextEdit::new(
+                    lsp::Range::new(lsp::Position::new(2, 5), lsp::Position::new(2, 5)),
+                    "\nadditional edit".to_string(),
+                )]),
+                ..Default::default()
+            },
+        )
+        .await;
+
+        apply_additional_edits.await.unwrap();
+        assert_eq!(
+            editor.read_with(&cx, |editor, cx| editor.text(cx)),
+            "
+                    one.second_completion
+                    two
+                    three
+                    additional edit
+                "
+            .unindent()
+        );
     }
 
     #[gpui::test]

crates/editor/src/multi_buffer.rs 🔗

@@ -1,7 +1,7 @@
 mod anchor;
 
 pub use anchor::{Anchor, AnchorRangeExt};
-use anyhow::Result;
+use anyhow::{anyhow, Result};
 use clock::ReplicaId;
 use collections::{HashMap, HashSet};
 use gpui::{AppContext, Entity, ModelContext, ModelHandle, Task};
@@ -929,6 +929,34 @@ impl MultiBuffer {
         }
     }
 
+    pub fn apply_completion(
+        &self,
+        completion: Completion<Anchor>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        let buffer = if let Some(buffer) = self
+            .buffers
+            .borrow()
+            .get(&completion.old_range.start.buffer_id)
+        {
+            buffer.buffer.clone()
+        } else {
+            return Task::ready(Err(anyhow!("completion cannot be applied to any buffer")));
+        };
+
+        buffer.update(cx, |buffer, cx| {
+            buffer.apply_completion(
+                Completion {
+                    old_range: completion.old_range.start.text_anchor
+                        ..completion.old_range.end.text_anchor,
+                    new_text: completion.new_text,
+                    lsp_completion: completion.lsp_completion,
+                },
+                cx,
+            )
+        })
+    }
+
     pub fn language<'a>(&self, cx: &'a AppContext) -> Option<&'a Arc<Language>> {
         self.buffers
             .borrow()

crates/language/src/buffer.rs 🔗

@@ -114,7 +114,7 @@ pub struct Diagnostic {
     pub is_disk_based: bool,
 }
 
-#[derive(Debug)]
+#[derive(Clone, Debug)]
 pub struct Completion<T> {
     pub old_range: Range<T>,
     pub new_text: String,
@@ -165,6 +165,10 @@ pub enum Event {
 pub trait File {
     fn as_local(&self) -> Option<&dyn LocalFile>;
 
+    fn is_local(&self) -> bool {
+        self.as_local().is_some()
+    }
+
     fn mtime(&self) -> SystemTime;
 
     /// Returns the path of this file relative to the worktree's root directory.
@@ -567,21 +571,7 @@ impl Buffer {
                 if let Some(edits) = edits {
                     this.update(&mut cx, |this, cx| {
                         if this.version == version {
-                            for edit in &edits {
-                                let range = range_from_lsp(edit.range);
-                                if this.clip_point_utf16(range.start, Bias::Left) != range.start
-                                    || this.clip_point_utf16(range.end, Bias::Left) != range.end
-                                {
-                                    return Err(anyhow!(
-                                        "invalid formatting edits received from language server"
-                                    ));
-                                }
-                            }
-
-                            for edit in edits.into_iter().rev() {
-                                this.edit([range_from_lsp(edit.range)], edit.new_text, cx);
-                            }
-                            Ok(())
+                            this.apply_lsp_edits(edits, cx)
                         } else {
                             Err(anyhow!("buffer edited since starting to format"))
                         }
@@ -1390,13 +1380,6 @@ impl Buffer {
         self.edit_internal(ranges_iter, new_text, true, cx)
     }
 
-    /*
-    impl Buffer
-        pub fn edit
-        pub fn edit_internal
-        pub fn edit_with_autoindent
-    */
-
     pub fn edit_internal<I, S, T>(
         &mut self,
         ranges_iter: I,
@@ -1485,6 +1468,29 @@ impl Buffer {
         self.send_operation(Operation::Buffer(text::Operation::Edit(edit)), cx);
     }
 
+    fn apply_lsp_edits(
+        &mut self,
+        edits: Vec<lsp::TextEdit>,
+        cx: &mut ModelContext<Self>,
+    ) -> Result<()> {
+        for edit in &edits {
+            let range = range_from_lsp(edit.range);
+            if self.clip_point_utf16(range.start, Bias::Left) != range.start
+                || self.clip_point_utf16(range.end, Bias::Left) != range.end
+            {
+                return Err(anyhow!(
+                    "invalid formatting edits received from language server"
+                ));
+            }
+        }
+
+        for edit in edits.into_iter().rev() {
+            self.edit([range_from_lsp(edit.range)], edit.new_text, cx);
+        }
+
+        Ok(())
+    }
+
     fn did_edit(
         &mut self,
         old_version: &clock::Global,
@@ -1752,13 +1758,17 @@ impl Buffer {
                             },
                         };
 
-                        let old_range = this.anchor_before(old_range.start)..this.anchor_after(old_range.end);
-
-                        Some(Completion {
-                            old_range,
-                            new_text,
-                            lsp_completion,
-                        })
+                        let clipped_start = this.clip_point_utf16(old_range.start, Bias::Left);
+                        let clipped_end = this.clip_point_utf16(old_range.end, Bias::Left) ;
+                        if clipped_start == old_range.start && clipped_end == old_range.end {
+                            Some(Completion {
+                                old_range: this.anchor_before(old_range.start)..this.anchor_after(old_range.end),
+                                new_text,
+                                lsp_completion,
+                            })
+                        } else {
+                            None
+                        }
                     }).collect())
                 })
             })
@@ -1766,6 +1776,41 @@ impl Buffer {
             Task::ready(Ok(Default::default()))
         }
     }
+
+    pub fn apply_completion(
+        &mut self,
+        completion: Completion<Anchor>,
+        cx: &mut ModelContext<Self>,
+    ) -> Task<Result<()>> {
+        self.edit_with_autoindent([completion.old_range], completion.new_text.clone(), cx);
+
+        let file = if let Some(file) = self.file.as_ref() {
+            file
+        } else {
+            return Task::ready(Ok(Default::default()));
+        };
+        if file.is_local() {
+            let server = if let Some(lang) = self.language_server.as_ref() {
+                lang.server.clone()
+            } else {
+                return Task::ready(Ok(Default::default()));
+            };
+
+            cx.spawn(|this, mut cx| async move {
+                let resolved_completion = server
+                    .request::<lsp::request::ResolveCompletionItem>(completion.lsp_completion)
+                    .await?;
+                if let Some(additional_edits) = resolved_completion.additional_text_edits {
+                    this.update(&mut cx, |this, cx| {
+                        this.apply_lsp_edits(additional_edits, cx)
+                    })?;
+                }
+                Ok::<_, anyhow::Error>(())
+            })
+        } else {
+            return Task::ready(Ok(Default::default()));
+        }
+    }
 }
 
 #[cfg(any(test, feature = "test-support"))]