Support wrapping and hard newlines in inline assistant (#12683)

Antonio Scandurra created

Release Notes:

- Improved UX for the inline assistant. It will now automatically wrap
when the text gets too long, and you can insert newlines using
`shift-enter`.

Change summary

crates/assistant/src/inline_assistant.rs   | 137 +++++++++++++++++----
crates/editor/src/display_map.rs           |  51 +++++++
crates/editor/src/display_map/block_map.rs | 151 +++++++++++++++++++++++
crates/editor/src/editor.rs                |  14 +
4 files changed, 315 insertions(+), 38 deletions(-)

Detailed changes

crates/assistant/src/inline_assistant.rs 🔗

@@ -7,7 +7,9 @@ use client::telemetry::Telemetry;
 use collections::{hash_map, HashMap, HashSet, VecDeque};
 use editor::{
     actions::{MoveDown, MoveUp},
-    display_map::{BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle},
+    display_map::{
+        BlockContext, BlockDisposition, BlockId, BlockProperties, BlockStyle, RenderBlock,
+    },
     scroll::{Autoscroll, AutoscrollStrategy},
     Anchor, Editor, EditorElement, EditorEvent, EditorStyle, GutterDimensions, MultiBuffer,
     MultiBufferSnapshot, ToOffset, ToPoint,
@@ -105,11 +107,11 @@ impl InlineAssistant {
             )
         });
 
-        let measurements = Arc::new(Mutex::new(GutterDimensions::default()));
-        let inline_assistant = cx.new_view(|cx| {
+        let gutter_dimensions = Arc::new(Mutex::new(GutterDimensions::default()));
+        let inline_assist_editor = cx.new_view(|cx| {
             InlineAssistEditor::new(
                 inline_assist_id,
-                measurements.clone(),
+                gutter_dimensions.clone(),
                 self.prompt_history.clone(),
                 codegen.clone(),
                 cx,
@@ -121,16 +123,13 @@ impl InlineAssistant {
             });
             editor.insert_blocks(
                 [BlockProperties {
-                    style: BlockStyle::Flex,
+                    style: BlockStyle::Sticky,
                     position: snapshot.anchor_before(Point::new(point_selection.head().row, 0)),
-                    height: 2,
-                    render: Box::new({
-                        let inline_assistant = inline_assistant.clone();
-                        move |cx: &mut BlockContext| {
-                            *measurements.lock() = *cx.gutter_dimensions;
-                            inline_assistant.clone().into_any_element()
-                        }
-                    }),
+                    height: inline_assist_editor.read(cx).height_in_lines,
+                    render: build_inline_assist_editor_renderer(
+                        &inline_assist_editor,
+                        gutter_dimensions,
+                    ),
                     disposition: if selection.reversed {
                         BlockDisposition::Above
                     } else {
@@ -147,22 +146,24 @@ impl InlineAssistant {
             PendingInlineAssist {
                 include_conversation,
                 editor: editor.downgrade(),
-                inline_assistant: Some((block_id, inline_assistant.clone())),
+                inline_assist_editor: Some((block_id, inline_assist_editor.clone())),
                 codegen: codegen.clone(),
                 workspace,
                 _subscriptions: vec![
-                    cx.subscribe(&inline_assistant, |inline_assistant, event, cx| {
+                    cx.subscribe(&inline_assist_editor, |inline_assist_editor, event, cx| {
                         InlineAssistant::update_global(cx, |this, cx| {
-                            this.handle_inline_assistant_event(inline_assistant, event, cx)
+                            this.handle_inline_assistant_event(inline_assist_editor, event, cx)
                         })
                     }),
                     cx.subscribe(editor, {
-                        let inline_assistant = inline_assistant.downgrade();
+                        let inline_assist_editor = inline_assist_editor.downgrade();
                         move |editor, event, cx| {
-                            if let Some(inline_assistant) = inline_assistant.upgrade() {
+                            if let Some(inline_assist_editor) = inline_assist_editor.upgrade() {
                                 if let EditorEvent::SelectionsChanged { local } = event {
                                     if *local
-                                        && inline_assistant.focus_handle(cx).contains_focused(cx)
+                                        && inline_assist_editor
+                                            .focus_handle(cx)
+                                            .contains_focused(cx)
                                     {
                                         cx.focus_view(&editor);
                                     }
@@ -199,7 +200,7 @@ impl InlineAssistant {
                                     .error()
                                     .map(|error| format!("Inline assistant error: {}", error));
                                 if let Some(error) = error {
-                                    if pending_assist.inline_assistant.is_none() {
+                                    if pending_assist.inline_assist_editor.is_none() {
                                         if let Some(workspace) = pending_assist
                                             .workspace
                                             .as_ref()
@@ -243,11 +244,11 @@ impl InlineAssistant {
 
     fn handle_inline_assistant_event(
         &mut self,
-        inline_assistant: View<InlineAssistEditor>,
+        inline_assist_editor: View<InlineAssistEditor>,
         event: &InlineAssistEditorEvent,
         cx: &mut WindowContext,
     ) {
-        let assist_id = inline_assistant.read(cx).id;
+        let assist_id = inline_assist_editor.read(cx).id;
         match event {
             InlineAssistEditorEvent::Confirmed { prompt } => {
                 self.confirm_inline_assist(assist_id, prompt, cx);
@@ -258,6 +259,9 @@ impl InlineAssistant {
             InlineAssistEditorEvent::Dismissed => {
                 self.hide_inline_assist(assist_id, cx);
             }
+            InlineAssistEditorEvent::Resized { height_in_lines } => {
+                self.resize_inline_assist(assist_id, *height_in_lines, cx);
+            }
         }
     }
 
@@ -311,10 +315,12 @@ impl InlineAssistant {
     fn hide_inline_assist(&mut self, assist_id: InlineAssistId, cx: &mut WindowContext) {
         if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
             if let Some(editor) = pending_assist.editor.upgrade() {
-                if let Some((block_id, inline_assistant)) = pending_assist.inline_assistant.take() {
+                if let Some((block_id, inline_assist_editor)) =
+                    pending_assist.inline_assist_editor.take()
+                {
                     editor.update(cx, |editor, cx| {
                         editor.remove_blocks(HashSet::from_iter([block_id]), None, cx);
-                        if inline_assistant.focus_handle(cx).contains_focused(cx) {
+                        if inline_assist_editor.focus_handle(cx).contains_focused(cx) {
                             editor.focus(cx);
                         }
                     });
@@ -323,6 +329,39 @@ impl InlineAssistant {
         }
     }
 
+    fn resize_inline_assist(
+        &mut self,
+        assist_id: InlineAssistId,
+        height_in_lines: u8,
+        cx: &mut WindowContext,
+    ) {
+        if let Some(pending_assist) = self.pending_assists.get_mut(&assist_id) {
+            if let Some(editor) = pending_assist.editor.upgrade() {
+                if let Some((block_id, inline_assist_editor)) =
+                    pending_assist.inline_assist_editor.as_ref()
+                {
+                    let gutter_dimensions = inline_assist_editor.read(cx).gutter_dimensions.clone();
+                    let mut new_blocks = HashMap::default();
+                    new_blocks.insert(
+                        *block_id,
+                        (
+                            Some(height_in_lines),
+                            build_inline_assist_editor_renderer(
+                                inline_assist_editor,
+                                gutter_dimensions,
+                            ),
+                        ),
+                    );
+                    editor.update(cx, |editor, cx| {
+                        editor
+                            .display_map
+                            .update(cx, |map, cx| map.replace_blocks(new_blocks, cx))
+                    });
+                }
+            }
+        }
+    }
+
     fn confirm_inline_assist(
         &mut self,
         assist_id: InlineAssistId,
@@ -498,6 +537,17 @@ impl InlineAssistant {
     }
 }
 
+fn build_inline_assist_editor_renderer(
+    editor: &View<InlineAssistEditor>,
+    gutter_dimensions: Arc<Mutex<GutterDimensions>>,
+) -> RenderBlock {
+    let editor = editor.clone();
+    Box::new(move |cx: &mut BlockContext| {
+        *gutter_dimensions.lock() = *cx.gutter_dimensions;
+        editor.clone().into_any_element()
+    })
+}
+
 #[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)]
 struct InlineAssistId(usize);
 
@@ -513,10 +563,12 @@ enum InlineAssistEditorEvent {
     Confirmed { prompt: String },
     Canceled,
     Dismissed,
+    Resized { height_in_lines: u8 },
 }
 
 struct InlineAssistEditor {
     id: InlineAssistId,
+    height_in_lines: u8,
     prompt_editor: View<Editor>,
     confirmed: bool,
     gutter_dimensions: Arc<Mutex<GutterDimensions>>,
@@ -535,7 +587,7 @@ impl Render for InlineAssistEditor {
         let icon_size = IconSize::default();
         h_flex()
             .w_full()
-            .py_2()
+            .py_1p5()
             .border_y_1()
             .border_color(cx.theme().colors().border)
             .bg(cx.theme().colors().editor_background)
@@ -564,7 +616,7 @@ impl Render for InlineAssistEditor {
                         None
                     }),
             )
-            .child(h_flex().flex_1().child(self.render_prompt_editor(cx)))
+            .child(div().flex_1().child(self.render_prompt_editor(cx)))
     }
 }
 
@@ -575,6 +627,8 @@ impl FocusableView for InlineAssistEditor {
 }
 
 impl InlineAssistEditor {
+    const MAX_LINES: u8 = 8;
+
     #[allow(clippy::too_many_arguments)]
     fn new(
         id: InlineAssistId,
@@ -584,7 +638,8 @@ impl InlineAssistEditor {
         cx: &mut ViewContext<Self>,
     ) -> Self {
         let prompt_editor = cx.new_view(|cx| {
-            let mut editor = Editor::single_line(cx);
+            let mut editor = Editor::auto_height(Self::MAX_LINES as usize, cx);
+            editor.set_soft_wrap_mode(language::language_settings::SoftWrap::EditorWidth, cx);
             let placeholder = match codegen.read(cx).kind() {
                 CodegenKind::Transform { .. } => "Enter transformation prompt…",
                 CodegenKind::Generate { .. } => "Enter generation prompt…",
@@ -596,11 +651,13 @@ impl InlineAssistEditor {
 
         let subscriptions = vec![
             cx.observe(&codegen, Self::handle_codegen_changed),
+            cx.observe(&prompt_editor, Self::handle_prompt_editor_changed),
             cx.subscribe(&prompt_editor, Self::handle_prompt_editor_events),
         ];
 
-        Self {
+        let mut this = Self {
             id,
+            height_in_lines: 1,
             prompt_editor,
             confirmed: false,
             gutter_dimensions,
@@ -609,9 +666,31 @@ impl InlineAssistEditor {
             pending_prompt: String::new(),
             codegen,
             _subscriptions: subscriptions,
+        };
+        this.count_lines(cx);
+        this
+    }
+
+    fn count_lines(&mut self, cx: &mut ViewContext<Self>) {
+        let height_in_lines = cmp::max(
+            2, // Make the editor at least two lines tall, to account for padding.
+            cmp::min(
+                self.prompt_editor
+                    .update(cx, |editor, cx| editor.max_point(cx).row().0 + 1),
+                Self::MAX_LINES as u32,
+            ),
+        ) as u8;
+
+        if height_in_lines != self.height_in_lines {
+            self.height_in_lines = height_in_lines;
+            cx.emit(InlineAssistEditorEvent::Resized { height_in_lines });
         }
     }
 
+    fn handle_prompt_editor_changed(&mut self, _: View<Editor>, cx: &mut ViewContext<Self>) {
+        self.count_lines(cx);
+    }
+
     fn handle_prompt_editor_events(
         &mut self,
         _: View<Editor>,
@@ -727,7 +806,7 @@ impl InlineAssistEditor {
 
 struct PendingInlineAssist {
     editor: WeakView<Editor>,
-    inline_assistant: Option<(BlockId, View<InlineAssistEditor>)>,
+    inline_assist_editor: Option<(BlockId, View<InlineAssistEditor>)>,
     codegen: Model<Codegen>,
     _subscriptions: Vec<Subscription>,
     workspace: Option<WeakView<Workspace>>,

crates/editor/src/display_map.rs 🔗

@@ -277,8 +277,55 @@ impl DisplayMap {
         block_map.insert(blocks)
     }
 
-    pub fn replace_blocks(&mut self, styles: HashMap<BlockId, RenderBlock>) {
-        self.block_map.replace(styles);
+    pub fn replace_blocks(
+        &mut self,
+        heights_and_renderers: HashMap<BlockId, (Option<u8>, RenderBlock)>,
+        cx: &mut ModelContext<Self>,
+    ) {
+        //
+        // Note: previous implementation of `replace_blocks` simply called
+        // `self.block_map.replace(styles)` which just modified the render by replacing
+        // the `RenderBlock` with the new one.
+        //
+        // ```rust
+        //  for block in &self.blocks {
+        //           if let Some(render) = renderers.remove(&block.id) {
+        //               *block.render.lock() = render;
+        //           }
+        //       }
+        // ```
+        //
+        // If height changes however, we need to update the tree. There's a performance
+        // cost to this, so we'll split the replace blocks into handling the old behavior
+        // directly and the new behavior separately.
+        //
+        //
+        let mut only_renderers = HashMap::<BlockId, RenderBlock>::default();
+        let mut full_replace = HashMap::<BlockId, (u8, RenderBlock)>::default();
+        for (id, (height, render)) in heights_and_renderers {
+            if let Some(height) = height {
+                full_replace.insert(id, (height, render));
+            } else {
+                only_renderers.insert(id, render);
+            }
+        }
+        self.block_map.replace_renderers(only_renderers);
+
+        if full_replace.is_empty() {
+            return;
+        }
+
+        let snapshot = self.buffer.read(cx).snapshot(cx);
+        let edits = self.buffer_subscription.consume().into_inner();
+        let tab_size = Self::tab_size(&self.buffer, cx);
+        let (snapshot, edits) = self.inlay_map.sync(snapshot, edits);
+        let (snapshot, edits) = self.fold_map.read(snapshot, edits);
+        let (snapshot, edits) = self.tab_map.sync(snapshot, edits, tab_size);
+        let (snapshot, edits) = self
+            .wrap_map
+            .update(cx, |map, cx| map.sync(snapshot, edits, cx));
+        let mut block_map = self.block_map.write(snapshot, edits);
+        block_map.replace(full_replace);
     }
 
     pub fn remove_blocks(&mut self, ids: HashSet<BlockId>, cx: &mut ModelContext<Self>) {

crates/editor/src/display_map/block_map.rs 🔗

@@ -467,8 +467,8 @@ impl BlockMap {
         *transforms = new_transforms;
     }
 
-    pub fn replace(&mut self, mut renderers: HashMap<BlockId, RenderBlock>) {
-        for block in &self.blocks {
+    pub fn replace_renderers(&mut self, mut renderers: HashMap<BlockId, RenderBlock>) {
+        for block in &mut self.blocks {
             if let Some(render) = renderers.remove(&block.id) {
                 *block.render.lock() = render;
             }
@@ -659,6 +659,48 @@ impl<'a> BlockMapWriter<'a> {
         ids
     }
 
+    pub fn replace(&mut self, mut heights_and_renderers: HashMap<BlockId, (u8, RenderBlock)>) {
+        let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
+        let buffer = wrap_snapshot.buffer_snapshot();
+        let mut edits = Patch::default();
+        let mut last_block_buffer_row = None;
+
+        for block in &mut self.0.blocks {
+            if let Some((new_height, render)) = heights_and_renderers.remove(&block.id) {
+                if block.height != new_height {
+                    let new_block = Block {
+                        id: block.id,
+                        position: block.position,
+                        height: new_height,
+                        style: block.style,
+                        render: Mutex::new(render),
+                        disposition: block.disposition,
+                    };
+                    *block = Arc::new(new_block);
+
+                    let buffer_row = block.position.to_point(buffer).row;
+                    if last_block_buffer_row != Some(buffer_row) {
+                        last_block_buffer_row = Some(buffer_row);
+                        let wrap_row = wrap_snapshot
+                            .make_wrap_point(Point::new(buffer_row, 0), Bias::Left)
+                            .row();
+                        let start_row =
+                            wrap_snapshot.prev_row_boundary(WrapPoint::new(wrap_row, 0));
+                        let end_row = wrap_snapshot
+                            .next_row_boundary(WrapPoint::new(wrap_row, 0))
+                            .unwrap_or(wrap_snapshot.max_point().row() + 1);
+                        edits.push(Edit {
+                            old: start_row..end_row,
+                            new: start_row..end_row,
+                        })
+                    }
+                }
+            }
+        }
+
+        self.0.sync(wrap_snapshot, edits);
+    }
+
     pub fn remove(&mut self, block_ids: HashSet<BlockId>) {
         let wrap_snapshot = &*self.0.wrap_snapshot.borrow();
         let buffer = wrap_snapshot.buffer_snapshot();
@@ -1305,6 +1347,111 @@ mod tests {
         assert_eq!(snapshot.text(), "aaa\n\nb!!!\n\n\nbb\nccc\nddd\n\n\n");
     }
 
+    #[gpui::test]
+    fn test_replace_with_heights(cx: &mut gpui::TestAppContext) {
+        let _update = cx.update(|cx| init_test(cx));
+
+        let text = "aaa\nbbb\nccc\nddd";
+
+        let buffer = cx.update(|cx| MultiBuffer::build_simple(text, cx));
+        let buffer_snapshot = cx.update(|cx| buffer.read(cx).snapshot(cx));
+        let _subscription = buffer.update(cx, |buffer, _| buffer.subscribe());
+        let (_inlay_map, inlay_snapshot) = InlayMap::new(buffer_snapshot.clone());
+        let (_fold_map, fold_snapshot) = FoldMap::new(inlay_snapshot);
+        let (_tab_map, tab_snapshot) = TabMap::new(fold_snapshot, 1.try_into().unwrap());
+        let (_wrap_map, wraps_snapshot) =
+            cx.update(|cx| WrapMap::new(tab_snapshot, font("Helvetica"), px(14.0), None, cx));
+        let mut block_map = BlockMap::new(wraps_snapshot.clone(), false, 1, 1, 0);
+
+        let mut writer = block_map.write(wraps_snapshot.clone(), Default::default());
+        let block_ids = writer.insert(vec![
+            BlockProperties {
+                style: BlockStyle::Fixed,
+                position: buffer_snapshot.anchor_after(Point::new(1, 0)),
+                height: 1,
+                disposition: BlockDisposition::Above,
+                render: Box::new(|_| div().into_any()),
+            },
+            BlockProperties {
+                style: BlockStyle::Fixed,
+                position: buffer_snapshot.anchor_after(Point::new(1, 2)),
+                height: 2,
+                disposition: BlockDisposition::Above,
+                render: Box::new(|_| div().into_any()),
+            },
+            BlockProperties {
+                style: BlockStyle::Fixed,
+                position: buffer_snapshot.anchor_after(Point::new(3, 3)),
+                height: 3,
+                disposition: BlockDisposition::Below,
+                render: Box::new(|_| div().into_any()),
+            },
+        ]);
+
+        {
+            let snapshot = block_map.read(wraps_snapshot.clone(), Default::default());
+            assert_eq!(snapshot.text(), "aaa\n\n\n\nbbb\nccc\nddd\n\n\n");
+
+            let mut block_map_writer = block_map.write(wraps_snapshot.clone(), Default::default());
+
+            let mut hash_map = HashMap::default();
+            let render: RenderBlock = Box::new(|_| div().into_any());
+            hash_map.insert(block_ids[0], (2_u8, render));
+            block_map_writer.replace(hash_map);
+            let snapshot = block_map.read(wraps_snapshot.clone(), Default::default());
+            assert_eq!(snapshot.text(), "aaa\n\n\n\n\nbbb\nccc\nddd\n\n\n");
+        }
+
+        {
+            let mut block_map_writer = block_map.write(wraps_snapshot.clone(), Default::default());
+
+            let mut hash_map = HashMap::default();
+            let render: RenderBlock = Box::new(|_| div().into_any());
+            hash_map.insert(block_ids[0], (1_u8, render));
+            block_map_writer.replace(hash_map);
+
+            let snapshot = block_map.read(wraps_snapshot.clone(), Default::default());
+            assert_eq!(snapshot.text(), "aaa\n\n\n\nbbb\nccc\nddd\n\n\n");
+        }
+
+        {
+            let mut block_map_writer = block_map.write(wraps_snapshot.clone(), Default::default());
+
+            let mut hash_map = HashMap::default();
+            let render: RenderBlock = Box::new(|_| div().into_any());
+            hash_map.insert(block_ids[0], (0_u8, render));
+            block_map_writer.replace(hash_map);
+
+            let snapshot = block_map.read(wraps_snapshot.clone(), Default::default());
+            assert_eq!(snapshot.text(), "aaa\n\n\nbbb\nccc\nddd\n\n\n");
+        }
+
+        {
+            let mut block_map_writer = block_map.write(wraps_snapshot.clone(), Default::default());
+
+            let mut hash_map = HashMap::default();
+            let render: RenderBlock = Box::new(|_| div().into_any());
+            hash_map.insert(block_ids[0], (3_u8, render));
+            block_map_writer.replace(hash_map);
+
+            let snapshot = block_map.read(wraps_snapshot.clone(), Default::default());
+            assert_eq!(snapshot.text(), "aaa\n\n\n\n\n\nbbb\nccc\nddd\n\n\n");
+        }
+
+        {
+            let mut block_map_writer = block_map.write(wraps_snapshot.clone(), Default::default());
+
+            let mut hash_map = HashMap::default();
+            let render: RenderBlock = Box::new(|_| div().into_any());
+            hash_map.insert(block_ids[0], (3_u8, render));
+            block_map_writer.replace(hash_map);
+
+            let snapshot = block_map.read(wraps_snapshot.clone(), Default::default());
+            // Same height as before, should remain the same
+            assert_eq!(snapshot.text(), "aaa\n\n\n\n\n\nbbb\nccc\nddd\n\n\n");
+        }
+    }
+
     #[gpui::test]
     fn test_blocks_on_wrapped_lines(cx: &mut gpui::TestAppContext) {
         cx.update(|cx| init_test(cx));

crates/editor/src/editor.rs 🔗

@@ -9263,11 +9263,15 @@ impl Editor {
                 for (block_id, diagnostic) in &active_diagnostics.blocks {
                     new_styles.insert(
                         *block_id,
-                        diagnostic_block_renderer(diagnostic.clone(), is_valid),
+                        (
+                            None,
+                            diagnostic_block_renderer(diagnostic.clone(), is_valid),
+                        ),
                     );
                 }
-                self.display_map
-                    .update(cx, |display_map, _| display_map.replace_blocks(new_styles));
+                self.display_map.update(cx, |display_map, cx| {
+                    display_map.replace_blocks(new_styles, cx)
+                });
             }
         }
     }
@@ -9624,12 +9628,12 @@ impl Editor {
 
     pub fn replace_blocks(
         &mut self,
-        blocks: HashMap<BlockId, RenderBlock>,
+        blocks: HashMap<BlockId, (Option<u8>, RenderBlock)>,
         autoscroll: Option<Autoscroll>,
         cx: &mut ViewContext<Self>,
     ) {
         self.display_map
-            .update(cx, |display_map, _| display_map.replace_blocks(blocks));
+            .update(cx, |display_map, cx| display_map.replace_blocks(blocks, cx));
         if let Some(autoscroll) = autoscroll {
             self.request_autoscroll(autoscroll, cx);
         }