assistant_context_editor: Fix copy paste regression (#31882)

Bennet Bo Fenner created

Closes #31166

Release Notes:

- Fixed an issue where copying and pasting an assistant response in text
threads would result in duplicate text

Change summary

Cargo.lock                                            |   1 
crates/assistant_context_editor/Cargo.toml            |   1 
crates/assistant_context_editor/src/context_editor.rs | 308 +++++++++---
3 files changed, 227 insertions(+), 83 deletions(-)

Detailed changes

Cargo.lock 🔗

@@ -525,6 +525,7 @@ dependencies = [
  "fuzzy",
  "gpui",
  "indexed_docs",
+ "indoc",
  "language",
  "language_model",
  "languages",

crates/assistant_context_editor/Cargo.toml 🔗

@@ -60,6 +60,7 @@ zed_actions.workspace = true
 zed_llm_client.workspace = true
 
 [dev-dependencies]
+indoc.workspace = true
 language_model = { workspace = true, features = ["test-support"] }
 languages = { workspace = true, features = ["test-support"] }
 pretty_assertions.workspace = true

crates/assistant_context_editor/src/context_editor.rs 🔗

@@ -1646,34 +1646,35 @@ impl ContextEditor {
         let context = self.context.read(cx);
 
         let mut text = String::new();
-        for message in context.messages(cx) {
-            if message.offset_range.start >= selection.range().end {
-                break;
-            } else if message.offset_range.end >= selection.range().start {
-                let range = cmp::max(message.offset_range.start, selection.range().start)
-                    ..cmp::min(message.offset_range.end, selection.range().end);
-                if range.is_empty() {
-                    let snapshot = context.buffer().read(cx).snapshot();
-                    let point = snapshot.offset_to_point(range.start);
-                    selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
-                    selection.end = snapshot.point_to_offset(cmp::min(
-                        Point::new(point.row + 1, 0),
-                        snapshot.max_point(),
-                    ));
-                    for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
-                        text.push_str(chunk);
-                    }
-                } else {
-                    for chunk in context.buffer().read(cx).text_for_range(range) {
-                        text.push_str(chunk);
-                    }
-                    if message.offset_range.end < selection.range().end {
-                        text.push('\n');
+
+        // If selection is empty, we want to copy the entire line
+        if selection.range().is_empty() {
+            let snapshot = context.buffer().read(cx).snapshot();
+            let point = snapshot.offset_to_point(selection.range().start);
+            selection.start = snapshot.point_to_offset(Point::new(point.row, 0));
+            selection.end = snapshot
+                .point_to_offset(cmp::min(Point::new(point.row + 1, 0), snapshot.max_point()));
+            for chunk in context.buffer().read(cx).text_for_range(selection.range()) {
+                text.push_str(chunk);
+            }
+        } else {
+            for message in context.messages(cx) {
+                if message.offset_range.start >= selection.range().end {
+                    break;
+                } else if message.offset_range.end >= selection.range().start {
+                    let range = cmp::max(message.offset_range.start, selection.range().start)
+                        ..cmp::min(message.offset_range.end, selection.range().end);
+                    if !range.is_empty() {
+                        for chunk in context.buffer().read(cx).text_for_range(range) {
+                            text.push_str(chunk);
+                        }
+                        if message.offset_range.end < selection.range().end {
+                            text.push('\n');
+                        }
                     }
                 }
             }
         }
-
         (text, CopyMetadata { creases }, vec![selection])
     }
 
@@ -3264,74 +3265,92 @@ mod tests {
     use super::*;
     use fs::FakeFs;
     use gpui::{App, TestAppContext, VisualTestContext};
+    use indoc::indoc;
     use language::{Buffer, LanguageRegistry};
+    use pretty_assertions::assert_eq;
     use prompt_store::PromptBuilder;
+    use text::OffsetRangeExt;
     use unindent::Unindent;
     use util::path;
 
     #[gpui::test]
-    async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
-        cx.update(init_test);
-
-        let fs = FakeFs::new(cx.executor());
-        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
-        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
-        let context = cx.new(|cx| {
-            AssistantContext::local(
-                registry,
-                None,
-                None,
-                prompt_builder.clone(),
-                Arc::new(SlashCommandWorkingSet::default()),
-                cx,
-            )
-        });
-        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
-        let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
-        let workspace = window.root(cx).unwrap();
-        let cx = &mut VisualTestContext::from_window(*window, cx);
-
-        let context_editor = window
-            .update(cx, |_, window, cx| {
-                cx.new(|cx| {
-                    ContextEditor::for_context(
-                        context,
-                        fs,
-                        workspace.downgrade(),
-                        project,
-                        None,
-                        window,
-                        cx,
-                    )
-                })
-            })
-            .unwrap();
-
-        context_editor.update_in(cx, |context_editor, window, cx| {
-            context_editor.editor.update(cx, |editor, cx| {
-                editor.set_text("abc\ndef\nghi", window, cx);
-                editor.move_to_beginning(&Default::default(), window, cx);
-            })
-        });
-
-        context_editor.update_in(cx, |context_editor, window, cx| {
-            context_editor.editor.update(cx, |editor, cx| {
-                editor.copy(&Default::default(), window, cx);
-                editor.paste(&Default::default(), window, cx);
+    async fn test_copy_paste_whole_message(cx: &mut TestAppContext) {
+        let (context, context_editor, mut cx) = setup_context_editor_text(vec![
+            (Role::User, "What is the Zed editor?"),
+            (
+                Role::Assistant,
+                "Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.",
+            ),
+            (Role::User, ""),
+        ],cx).await;
+
+        // Select & Copy whole user message
+        assert_copy_paste_context_editor(
+            &context_editor,
+            message_range(&context, 0, &mut cx),
+            indoc! {"
+                What is the Zed editor?
+                Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
+                What is the Zed editor?
+            "},
+            &mut cx,
+        );
 
-                assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
-            })
-        });
+        // Select & Copy whole assistant message
+        assert_copy_paste_context_editor(
+            &context_editor,
+            message_range(&context, 1, &mut cx),
+            indoc! {"
+                What is the Zed editor?
+                Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
+                What is the Zed editor?
+                Zed is a modern, high-performance code editor designed from the ground up for speed and collaboration.
+            "},
+            &mut cx,
+        );
+    }
 
-        context_editor.update_in(cx, |context_editor, window, cx| {
-            context_editor.editor.update(cx, |editor, cx| {
-                editor.cut(&Default::default(), window, cx);
-                assert_eq!(editor.text(cx), "abc\ndef\nghi");
+    #[gpui::test]
+    async fn test_copy_paste_no_selection(cx: &mut TestAppContext) {
+        let (context, context_editor, mut cx) = setup_context_editor_text(
+            vec![
+                (Role::User, "user1"),
+                (Role::Assistant, "assistant1"),
+                (Role::Assistant, "assistant2"),
+                (Role::User, ""),
+            ],
+            cx,
+        )
+        .await;
+
+        // Copy and paste first assistant message
+        let message_2_range = message_range(&context, 1, &mut cx);
+        assert_copy_paste_context_editor(
+            &context_editor,
+            message_2_range.start..message_2_range.start,
+            indoc! {"
+                user1
+                assistant1
+                assistant2
+                assistant1
+            "},
+            &mut cx,
+        );
 
-                editor.paste(&Default::default(), window, cx);
-                assert_eq!(editor.text(cx), "abc\nabc\ndef\nghi");
-            })
-        });
+        // Copy and cut second assistant message
+        let message_3_range = message_range(&context, 2, &mut cx);
+        assert_copy_paste_context_editor(
+            &context_editor,
+            message_3_range.start..message_3_range.start,
+            indoc! {"
+                user1
+                assistant1
+                assistant2
+                assistant1
+                assistant2
+            "},
+            &mut cx,
+        );
     }
 
     #[gpui::test]
@@ -3408,6 +3427,129 @@ mod tests {
         }
     }
 
+    async fn setup_context_editor_text(
+        messages: Vec<(Role, &str)>,
+        cx: &mut TestAppContext,
+    ) -> (
+        Entity<AssistantContext>,
+        Entity<ContextEditor>,
+        VisualTestContext,
+    ) {
+        cx.update(init_test);
+
+        let fs = FakeFs::new(cx.executor());
+        let context = create_context_with_messages(messages, cx);
+
+        let project = Project::test(fs.clone(), [path!("/test").as_ref()], cx).await;
+        let window = cx.add_window(|window, cx| Workspace::test_new(project.clone(), window, cx));
+        let workspace = window.root(cx).unwrap();
+        let mut cx = VisualTestContext::from_window(*window, cx);
+
+        let context_editor = window
+            .update(&mut cx, |_, window, cx| {
+                cx.new(|cx| {
+                    let editor = ContextEditor::for_context(
+                        context.clone(),
+                        fs,
+                        workspace.downgrade(),
+                        project,
+                        None,
+                        window,
+                        cx,
+                    );
+                    editor
+                })
+            })
+            .unwrap();
+
+        (context, context_editor, cx)
+    }
+
+    fn message_range(
+        context: &Entity<AssistantContext>,
+        message_ix: usize,
+        cx: &mut TestAppContext,
+    ) -> Range<usize> {
+        context.update(cx, |context, cx| {
+            context
+                .messages(cx)
+                .nth(message_ix)
+                .unwrap()
+                .anchor_range
+                .to_offset(&context.buffer().read(cx).snapshot())
+        })
+    }
+
+    fn assert_copy_paste_context_editor<T: editor::ToOffset>(
+        context_editor: &Entity<ContextEditor>,
+        range: Range<T>,
+        expected_text: &str,
+        cx: &mut VisualTestContext,
+    ) {
+        context_editor.update_in(cx, |context_editor, window, cx| {
+            context_editor.editor.update(cx, |editor, cx| {
+                editor.change_selections(None, window, cx, |s| s.select_ranges([range]));
+            });
+
+            context_editor.copy(&Default::default(), window, cx);
+
+            context_editor.editor.update(cx, |editor, cx| {
+                editor.move_to_end(&Default::default(), window, cx);
+            });
+
+            context_editor.paste(&Default::default(), window, cx);
+
+            context_editor.editor.update(cx, |editor, cx| {
+                assert_eq!(editor.text(cx), expected_text);
+            });
+        });
+    }
+
+    fn create_context_with_messages(
+        mut messages: Vec<(Role, &str)>,
+        cx: &mut TestAppContext,
+    ) -> Entity<AssistantContext> {
+        let registry = Arc::new(LanguageRegistry::test(cx.executor()));
+        let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap());
+        cx.new(|cx| {
+            let mut context = AssistantContext::local(
+                registry,
+                None,
+                None,
+                prompt_builder.clone(),
+                Arc::new(SlashCommandWorkingSet::default()),
+                cx,
+            );
+            let mut message_1 = context.messages(cx).next().unwrap();
+            let (role, text) = messages.remove(0);
+
+            loop {
+                if role == message_1.role {
+                    context.buffer().update(cx, |buffer, cx| {
+                        buffer.edit([(message_1.offset_range, text)], None, cx);
+                    });
+                    break;
+                }
+                let mut ids = HashSet::default();
+                ids.insert(message_1.id);
+                context.cycle_message_roles(ids, cx);
+                message_1 = context.messages(cx).next().unwrap();
+            }
+
+            let mut last_message_id = message_1.id;
+            for (role, text) in messages {
+                context.insert_message_after(last_message_id, role, MessageStatus::Done, cx);
+                let message = context.messages(cx).last().unwrap();
+                last_message_id = message.id;
+                context.buffer().update(cx, |buffer, cx| {
+                    buffer.edit([(message.offset_range, text)], None, cx);
+                })
+            }
+
+            context
+        })
+    }
+
     fn init_test(cx: &mut App) {
         let settings_store = SettingsStore::test(cx);
         prompt_store::init(cx);