Add codegen_ranges function in inline_assistant.rs (#43186)

Michael Benfield , Mikayla Maki , and Richard Feldman created

Just a simple refactor.

Release Notes:

- N/A

---------

Co-authored-by: Mikayla Maki <mikayla.c.maki@gmail.com>
Co-authored-by: Richard Feldman <oss@rtfeldman.com>

Change summary

crates/agent_ui/src/inline_assistant.rs | 181 ++++++++++++++------------
1 file changed, 99 insertions(+), 82 deletions(-)

Detailed changes

crates/agent_ui/src/inline_assistant.rs 🔗

@@ -16,6 +16,7 @@ use agent_settings::AgentSettings;
 use anyhow::{Context as _, Result};
 use client::telemetry::Telemetry;
 use collections::{HashMap, HashSet, VecDeque, hash_map};
+use editor::EditorSnapshot;
 use editor::MultiBufferOffset;
 use editor::RowExt;
 use editor::SelectionEffects;
@@ -351,25 +352,20 @@ impl InlineAssistant {
         }
     }
 
-    pub fn assist(
+    fn codegen_ranges(
         &mut self,
         editor: &Entity<Editor>,
-        workspace: WeakEntity<Workspace>,
-        context_store: Entity<ContextStore>,
-        project: WeakEntity<Project>,
-        prompt_store: Option<Entity<PromptStore>>,
-        thread_store: Option<WeakEntity<HistoryStore>>,
-        initial_prompt: Option<String>,
+        snapshot: &EditorSnapshot,
         window: &mut Window,
         cx: &mut App,
-    ) {
-        let (snapshot, initial_selections, newest_selection) = editor.update(cx, |editor, cx| {
-            let snapshot = editor.snapshot(window, cx);
-            let selections = editor.selections.all::<Point>(&snapshot.display_snapshot);
-            let newest_selection = editor
-                .selections
-                .newest::<Point>(&snapshot.display_snapshot);
-            (snapshot, selections, newest_selection)
+    ) -> Option<(Vec<Range<Anchor>>, Selection<Point>)> {
+        let (initial_selections, newest_selection) = editor.update(cx, |editor, _| {
+            (
+                editor.selections.all::<Point>(&snapshot.display_snapshot),
+                editor
+                    .selections
+                    .newest::<Point>(&snapshot.display_snapshot),
+            )
         });
 
         // Check if there is already an inline assistant that contains the
@@ -382,7 +378,7 @@ impl InlineAssistant {
                     && newest_selection.end.row <= range.end.row
                 {
                     self.focus_assist(*assist_id, window, cx);
-                    return;
+                    return None;
                 }
             }
         }
@@ -474,6 +470,26 @@ impl InlineAssistant {
             }
         }
 
+        Some((codegen_ranges, newest_selection))
+    }
+
+    fn batch_assist(
+        &mut self,
+        editor: &Entity<Editor>,
+        workspace: WeakEntity<Workspace>,
+        context_store: Entity<ContextStore>,
+        project: WeakEntity<Project>,
+        prompt_store: Option<Entity<PromptStore>>,
+        thread_store: Option<WeakEntity<HistoryStore>>,
+        initial_prompt: Option<String>,
+        window: &mut Window,
+        codegen_ranges: &[Range<Anchor>],
+        newest_selection: Option<Selection<Point>>,
+        initial_transaction_id: Option<TransactionId>,
+        cx: &mut App,
+    ) -> Option<InlineAssistId> {
+        let snapshot = editor.update(cx, |editor, cx| editor.snapshot(window, cx));
+
         let assist_group_id = self.next_assist_group_id.post_inc();
         let prompt_buffer = cx.new(|cx| {
             MultiBuffer::singleton(
@@ -484,13 +500,14 @@ impl InlineAssistant {
 
         let mut assists = Vec::new();
         let mut assist_to_focus = None;
+
         for range in codegen_ranges {
             let assist_id = self.next_assist_id.post_inc();
             let codegen = cx.new(|cx| {
                 BufferCodegen::new(
                     editor.read(cx).buffer().clone(),
                     range.clone(),
-                    None,
+                    initial_transaction_id,
                     context_store.clone(),
                     project.clone(),
                     prompt_store.clone(),
@@ -518,11 +535,13 @@ impl InlineAssistant {
                 )
             });
 
-            if assist_to_focus.is_none() {
+            if let Some(newest_selection) = newest_selection.as_ref()
+                && assist_to_focus.is_none()
+            {
                 let focus_assist = if newest_selection.reversed {
-                    range.start.to_point(snapshot) == newest_selection.start
+                    range.start.to_point(&snapshot) == newest_selection.start
                 } else {
-                    range.end.to_point(snapshot) == newest_selection.end
+                    range.end.to_point(&snapshot) == newest_selection.end
                 };
                 if focus_assist {
                     assist_to_focus = Some(assist_id);
@@ -534,7 +553,7 @@ impl InlineAssistant {
 
             assists.push((
                 assist_id,
-                range,
+                range.clone(),
                 prompt_editor,
                 prompt_block_id,
                 end_block_id,
@@ -545,6 +564,15 @@ impl InlineAssistant {
             .assists_by_editor
             .entry(editor.downgrade())
             .or_insert_with(|| EditorInlineAssists::new(editor, window, cx));
+
+        let assist_to_focus = if let Some(focus_id) = assist_to_focus {
+            Some(focus_id)
+        } else if assists.len() >= 1 {
+            Some(assists[0].0)
+        } else {
+            None
+        };
+
         let mut assist_group = InlineAssistGroup::new();
         for (assist_id, range, prompt_editor, prompt_block_id, end_block_id) in assists {
             let codegen = prompt_editor.read(cx).codegen().clone();
@@ -568,8 +596,47 @@ impl InlineAssistant {
             assist_group.assist_ids.push(assist_id);
             editor_assists.assist_ids.push(assist_id);
         }
+
         self.assist_groups.insert(assist_group_id, assist_group);
 
+        assist_to_focus
+    }
+
+    pub fn assist(
+        &mut self,
+        editor: &Entity<Editor>,
+        workspace: WeakEntity<Workspace>,
+        context_store: Entity<ContextStore>,
+        project: WeakEntity<Project>,
+        prompt_store: Option<Entity<PromptStore>>,
+        thread_store: Option<WeakEntity<HistoryStore>>,
+        initial_prompt: Option<String>,
+        window: &mut Window,
+        cx: &mut App,
+    ) {
+        let snapshot = editor.update(cx, |editor, cx| editor.snapshot(window, cx));
+
+        let Some((codegen_ranges, newest_selection)) =
+            self.codegen_ranges(editor, &snapshot, window, cx)
+        else {
+            return;
+        };
+
+        let assist_to_focus = self.batch_assist(
+            editor,
+            workspace,
+            context_store,
+            project,
+            prompt_store,
+            thread_store,
+            initial_prompt,
+            window,
+            &codegen_ranges,
+            Some(newest_selection),
+            None,
+            cx,
+        );
+
         if let Some(assist_id) = assist_to_focus {
             self.focus_assist(assist_id, window, cx);
         }
@@ -588,12 +655,6 @@ impl InlineAssistant {
         window: &mut Window,
         cx: &mut App,
     ) -> InlineAssistId {
-        let assist_group_id = self.next_assist_group_id.post_inc();
-        let prompt_buffer = cx.new(|cx| Buffer::local(&initial_prompt, cx));
-        let prompt_buffer = cx.new(|cx| MultiBuffer::singleton(prompt_buffer, cx));
-
-        let assist_id = self.next_assist_id.post_inc();
-
         let buffer = editor.read(cx).buffer().clone();
         {
             let snapshot = buffer.read(cx).read(cx);
@@ -604,66 +665,22 @@ impl InlineAssistant {
         let project = workspace.read(cx).project().downgrade();
         let context_store = cx.new(|_cx| ContextStore::new(project.clone()));
 
-        let codegen = cx.new(|cx| {
-            BufferCodegen::new(
-                editor.read(cx).buffer().clone(),
-                range.clone(),
-                initial_transaction_id,
-                context_store.clone(),
-                project,
-                prompt_store.clone(),
-                self.telemetry.clone(),
-                self.prompt_builder.clone(),
-                cx,
-            )
-        });
-
-        let editor_margins = Arc::new(Mutex::new(EditorMargins::default()));
-        let prompt_editor = cx.new(|cx| {
-            PromptEditor::new_buffer(
-                assist_id,
-                editor_margins,
-                self.prompt_history.clone(),
-                prompt_buffer.clone(),
-                codegen.clone(),
-                self.fs.clone(),
-                context_store,
+        let assist_id = self
+            .batch_assist(
+                editor,
                 workspace.downgrade(),
+                context_store,
+                project,
+                prompt_store,
                 thread_store,
-                prompt_store.map(|s| s.downgrade()),
+                Some(initial_prompt),
                 window,
+                &[range],
+                None,
+                initial_transaction_id,
                 cx,
             )
-        });
-
-        let [prompt_block_id, end_block_id] =
-            self.insert_assist_blocks(editor, &range, &prompt_editor, cx);
-
-        let editor_assists = self
-            .assists_by_editor
-            .entry(editor.downgrade())
-            .or_insert_with(|| EditorInlineAssists::new(editor, window, cx));
-
-        let mut assist_group = InlineAssistGroup::new();
-        self.assists.insert(
-            assist_id,
-            InlineAssist::new(
-                assist_id,
-                assist_group_id,
-                editor,
-                &prompt_editor,
-                prompt_block_id,
-                end_block_id,
-                range,
-                codegen.clone(),
-                workspace.downgrade(),
-                window,
-                cx,
-            ),
-        );
-        assist_group.assist_ids.push(assist_id);
-        editor_assists.assist_ids.push(assist_id);
-        self.assist_groups.insert(assist_group_id, assist_group);
+            .expect("batch_assist returns an id if there's only one range");
 
         if focus {
             self.focus_assist(assist_id, window, cx);