assistant: Add imports in a single area when using workflows (#16355)

Bennet Bo Fenner , Kirill , and Thorsten created

Co-Authored-by: Kirill <kirill@zed.dev>

Release Notes:

- N/A

---------

Co-authored-by: Kirill <kirill@zed.dev>
Co-authored-by: Thorsten <thorsten@zed.dev>

Change summary

assets/prompts/step_resolution.hbs       |  3 
crates/assistant/src/assistant_panel.rs  | 50 +++++---------
crates/assistant/src/inline_assistant.rs | 35 ++++++++--
crates/assistant/src/workflow.rs         | 90 ++++++++++++++++++-------
crates/language/src/outline.rs           |  2 
5 files changed, 115 insertions(+), 65 deletions(-)

Detailed changes

assets/prompts/step_resolution.hbs 🔗

@@ -15,6 +15,7 @@ With each location, you will produce a brief, one-line description of the change
 - When generating multiple suggestions, ensure the descriptions are specific to each individual operation.
 - Avoid referring to the location in the description. Focus on the change to be made, not the location where it's made. That's implicit with the symbol you provide.
 - Don't generate multiple suggestions at the same location. Instead, combine them together in a single operation with a succinct combined description.
+- To add imports respond with a suggestion where the `"symbol"` key is set to `"#imports"`
 </guidelines>
 </overview>
 
@@ -203,6 +204,7 @@ Add a 'use std::fmt;' statement at the beginning of the file
     {
       "kind": "PrependChild",
       "path": "src/vehicle.rs",
+      "symbol": "#imports",
       "description": "Add 'use std::fmt' statement"
     }
   ]
@@ -413,6 +415,7 @@ Add a 'load_from_file' method to Config and import necessary modules
     {
       "kind": "PrependChild",
       "path": "src/config.rs",
+      "symbol": "#imports",
       "description": "Import std::fs and std::io modules"
     },
     {

crates/assistant/src/assistant_panel.rs 🔗

@@ -1719,7 +1719,6 @@ struct WorkflowAssist {
     editor: WeakView<Editor>,
     editor_was_open: bool,
     assist_ids: Vec<InlineAssistId>,
-    _observe_assist_status: Task<()>,
 }
 
 pub struct ContextEditor {
@@ -1862,13 +1861,25 @@ impl ContextEditor {
         if let Some(workflow_step) = self.workflow_steps.get(&range) {
             if let Some(assist) = workflow_step.assist.as_ref() {
                 let assist_ids = assist.assist_ids.clone();
-                cx.window_context().defer(|cx| {
-                    InlineAssistant::update_global(cx, |assistant, cx| {
-                        for assist_id in assist_ids {
-                            assistant.start_assist(assist_id, cx);
+                cx.spawn(|this, mut cx| async move {
+                    for assist_id in assist_ids {
+                        let mut receiver = this.update(&mut cx, |_, cx| {
+                            cx.window_context().defer(move |cx| {
+                                InlineAssistant::update_global(cx, |assistant, cx| {
+                                    assistant.start_assist(assist_id, cx);
+                                })
+                            });
+                            InlineAssistant::update_global(cx, |assistant, _| {
+                                assistant.observe_assist(assist_id)
+                            })
+                        })?;
+                        while !receiver.borrow().is_done() {
+                            let _ = receiver.changed().await;
                         }
-                    })
-                });
+                    }
+                    anyhow::Ok(())
+                })
+                .detach_and_log_err(cx);
             }
         }
     }
@@ -3006,35 +3017,10 @@ impl ContextEditor {
             }
         }
 
-        let mut observations = Vec::new();
-        InlineAssistant::update_global(cx, |assistant, _cx| {
-            for assist_id in &assist_ids {
-                observations.push(assistant.observe_assist(*assist_id));
-            }
-        });
-
         Some(WorkflowAssist {
             assist_ids,
             editor: editor.downgrade(),
             editor_was_open,
-            _observe_assist_status: cx.spawn(|this, mut cx| async move {
-                while !observations.is_empty() {
-                    let (result, ix, _) = futures::future::select_all(
-                        observations
-                            .iter_mut()
-                            .map(|observation| Box::pin(observation.changed())),
-                    )
-                    .await;
-
-                    if result.is_err() {
-                        observations.remove(ix);
-                    }
-
-                    if this.update(&mut cx, |_, cx| cx.notify()).is_err() {
-                        break;
-                    }
-                }
-            }),
         })
     }
 

crates/assistant/src/inline_assistant.rs 🔗

@@ -76,8 +76,13 @@ pub struct InlineAssistant {
     assists: HashMap<InlineAssistId, InlineAssist>,
     assists_by_editor: HashMap<WeakView<Editor>, EditorInlineAssists>,
     assist_groups: HashMap<InlineAssistGroupId, InlineAssistGroup>,
-    assist_observations:
-        HashMap<InlineAssistId, (async_watch::Sender<()>, async_watch::Receiver<()>)>,
+    assist_observations: HashMap<
+        InlineAssistId,
+        (
+            async_watch::Sender<AssistStatus>,
+            async_watch::Receiver<AssistStatus>,
+        ),
+    >,
     confirmed_assists: HashMap<InlineAssistId, Model<Codegen>>,
     prompt_history: VecDeque<String>,
     prompt_builder: Arc<PromptBuilder>,
@@ -85,6 +90,19 @@ pub struct InlineAssistant {
     fs: Arc<dyn Fs>,
 }
 
+pub enum AssistStatus {
+    Idle,
+    Started,
+    Stopped,
+    Finished,
+}
+
+impl AssistStatus {
+    pub fn is_done(&self) -> bool {
+        matches!(self, Self::Stopped | Self::Finished)
+    }
+}
+
 impl Global for InlineAssistant {}
 
 impl InlineAssistant {
@@ -925,7 +943,7 @@ impl InlineAssistant {
             .log_err();
 
         if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
-            tx.send(()).ok();
+            tx.send(AssistStatus::Started).ok();
         }
     }
 
@@ -939,7 +957,7 @@ impl InlineAssistant {
         assist.codegen.update(cx, |codegen, cx| codegen.stop(cx));
 
         if let Some((tx, _)) = self.assist_observations.get(&assist_id) {
-            tx.send(()).ok();
+            tx.send(AssistStatus::Stopped).ok();
         }
     }
 
@@ -1141,11 +1159,14 @@ impl InlineAssistant {
         })
     }
 
-    pub fn observe_assist(&mut self, assist_id: InlineAssistId) -> async_watch::Receiver<()> {
+    pub fn observe_assist(
+        &mut self,
+        assist_id: InlineAssistId,
+    ) -> async_watch::Receiver<AssistStatus> {
         if let Some((_, rx)) = self.assist_observations.get(&assist_id) {
             rx.clone()
         } else {
-            let (tx, rx) = async_watch::channel(());
+            let (tx, rx) = async_watch::channel(AssistStatus::Idle);
             self.assist_observations.insert(assist_id, (tx, rx.clone()));
             rx
         }
@@ -2079,7 +2100,7 @@ impl InlineAssist {
                             if assist.decorations.is_none() {
                                 this.finish_assist(assist_id, false, cx);
                             } else if let Some(tx) = this.assist_observations.get(&assist_id) {
-                                tx.0.send(()).ok();
+                                tx.0.send(AssistStatus::Finished).ok();
                             }
                         }
                     })

crates/assistant/src/workflow.rs 🔗

@@ -23,6 +23,8 @@ use workspace::Workspace;
 
 pub use step_view::WorkflowStepView;
 
+const IMPORTS_SYMBOL: &str = "#imports";
+
 pub struct WorkflowStep {
     context: WeakModel<Context>,
     context_buffer_range: Range<Anchor>,
@@ -467,7 +469,7 @@ pub mod tool {
     use super::*;
     use anyhow::Context as _;
     use gpui::AsyncAppContext;
-    use language::ParseStatus;
+    use language::{Outline, OutlineItem, ParseStatus};
     use language_model::LanguageModelTool;
     use project::ProjectPath;
     use schemars::JsonSchema;
@@ -562,10 +564,7 @@ pub mod tool {
                     symbol,
                     description,
                 } => {
-                    let (symbol_path, symbol) = outline
-                        .find_most_similar(&symbol)
-                        .with_context(|| format!("symbol not found: {:?}", symbol))?;
-                    let symbol = symbol.to_point(&snapshot);
+                    let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
                     let start = symbol
                         .annotation_range
                         .map_or(symbol.range.start, |range| range.start);
@@ -588,10 +587,7 @@ pub mod tool {
                     symbol,
                     description,
                 } => {
-                    let (symbol_path, symbol) = outline
-                        .find_most_similar(&symbol)
-                        .with_context(|| format!("symbol not found: {:?}", symbol))?;
-                    let symbol = symbol.to_point(&snapshot);
+                    let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
                     let position = snapshot.anchor_before(
                         symbol
                             .annotation_range
@@ -609,10 +605,7 @@ pub mod tool {
                     symbol,
                     description,
                 } => {
-                    let (symbol_path, symbol) = outline
-                        .find_most_similar(&symbol)
-                        .with_context(|| format!("symbol not found: {:?}", symbol))?;
-                    let symbol = symbol.to_point(&snapshot);
+                    let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
                     let position = snapshot.anchor_after(symbol.range.end);
                     WorkflowSuggestion::InsertSiblingAfter {
                         position,
@@ -625,10 +618,8 @@ pub mod tool {
                     description,
                 } => {
                     if let Some(symbol) = symbol {
-                        let (symbol_path, symbol) = outline
-                            .find_most_similar(&symbol)
-                            .with_context(|| format!("symbol not found: {:?}", symbol))?;
-                        let symbol = symbol.to_point(&snapshot);
+                        let (symbol_path, symbol) =
+                            Self::resolve_symbol(&snapshot, &outline, &symbol)?;
 
                         let position = snapshot.anchor_after(
                             symbol
@@ -653,10 +644,8 @@ pub mod tool {
                     description,
                 } => {
                     if let Some(symbol) = symbol {
-                        let (symbol_path, symbol) = outline
-                            .find_most_similar(&symbol)
-                            .with_context(|| format!("symbol not found: {:?}", symbol))?;
-                        let symbol = symbol.to_point(&snapshot);
+                        let (symbol_path, symbol) =
+                            Self::resolve_symbol(&snapshot, &outline, &symbol)?;
 
                         let position = snapshot.anchor_before(
                             symbol
@@ -677,10 +666,7 @@ pub mod tool {
                     }
                 }
                 WorkflowSuggestionToolKind::Delete { symbol } => {
-                    let (symbol_path, symbol) = outline
-                        .find_most_similar(&symbol)
-                        .with_context(|| format!("symbol not found: {:?}", symbol))?;
-                    let symbol = symbol.to_point(&snapshot);
+                    let (symbol_path, symbol) = Self::resolve_symbol(&snapshot, &outline, &symbol)?;
                     let start = symbol
                         .annotation_range
                         .map_or(symbol.range.start, |range| range.start);
@@ -696,6 +682,60 @@ pub mod tool {
 
             Ok((buffer, suggestion))
         }
+
+        fn resolve_symbol(
+            snapshot: &BufferSnapshot,
+            outline: &Outline<Anchor>,
+            symbol: &str,
+        ) -> Result<(SymbolPath, OutlineItem<Point>)> {
+            if symbol == IMPORTS_SYMBOL {
+                let target_row = find_first_non_comment_line(snapshot);
+                Ok((
+                    SymbolPath(IMPORTS_SYMBOL.to_string()),
+                    OutlineItem {
+                        range: Point::new(target_row, 0)..Point::new(target_row + 1, 0),
+                        ..Default::default()
+                    },
+                ))
+            } else {
+                let (symbol_path, symbol) = outline
+                    .find_most_similar(symbol)
+                    .with_context(|| format!("symbol not found: {symbol}"))?;
+                Ok((symbol_path, symbol.to_point(snapshot)))
+            }
+        }
+    }
+
+    fn find_first_non_comment_line(snapshot: &BufferSnapshot) -> u32 {
+        let Some(language) = snapshot.language() else {
+            return 0;
+        };
+
+        let scope = language.default_scope();
+        let comment_prefixes = scope.line_comment_prefixes();
+
+        let mut chunks = snapshot.as_rope().chunks();
+        let mut target_row = 0;
+        loop {
+            let starts_with_comment = chunks
+                .peek()
+                .map(|chunk| {
+                    comment_prefixes
+                        .iter()
+                        .any(|s| chunk.starts_with(s.as_ref().trim_end()))
+                })
+                .unwrap_or(false);
+
+            if !starts_with_comment {
+                break;
+            }
+
+            target_row += 1;
+            if !chunks.next_line() {
+                break;
+            }
+        }
+        target_row
     }
 
     #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, JsonSchema)]

crates/language/src/outline.rs 🔗

@@ -14,7 +14,7 @@ pub struct Outline<T> {
     path_candidate_prefixes: Vec<usize>,
 }
 
-#[derive(Clone, Debug, PartialEq, Eq, Hash)]
+#[derive(Clone, Debug, Default, PartialEq, Eq, Hash)]
 pub struct OutlineItem<T> {
     pub depth: usize,
     pub range: Range<T>,