edit_file_tool: Fail when edit location is not unique (#32056)

Oleksiy Syvokon created

When `<old_text>` points to more than one location in a file, we used to
edit the first match, confusing the agent along the way. Now we will
return an error, asking to expand `<old_text>` selection.

Closes #ISSUE

Release Notes:

- agent: Fixed incorrect file edits when edit locations are ambiguous

Change summary

crates/assistant_tools/src/edit_agent.rs                         | 134 +
crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs | 148 +
crates/assistant_tools/src/edit_file_tool.rs                     |  13 
3 files changed, 214 insertions(+), 81 deletions(-)

Detailed changes

crates/assistant_tools/src/edit_agent.rs 🔗

@@ -54,6 +54,7 @@ impl Template for EditFilePromptTemplate {
 pub enum EditAgentOutputEvent {
     ResolvingEditRange(Range<Anchor>),
     UnresolvedEditRange,
+    AmbiguousEditRange(Vec<Range<usize>>),
     Edited,
 }
 
@@ -269,16 +270,29 @@ impl EditAgent {
                 }
             }
 
-            let (edit_events_, resolved_old_text) = resolve_old_text.await?;
+            let (edit_events_, mut resolved_old_text) = resolve_old_text.await?;
             edit_events = edit_events_;
 
             // If we can't resolve the old text, restart the loop waiting for a
             // new edit (or for the stream to end).
-            let Some(resolved_old_text) = resolved_old_text else {
-                output_events
-                    .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
-                    .ok();
-                continue;
+            let resolved_old_text = match resolved_old_text.len() {
+                1 => resolved_old_text.pop().unwrap(),
+                0 => {
+                    output_events
+                        .unbounded_send(EditAgentOutputEvent::UnresolvedEditRange)
+                        .ok();
+                    continue;
+                }
+                _ => {
+                    let ranges = resolved_old_text
+                        .into_iter()
+                        .map(|text| text.range)
+                        .collect();
+                    output_events
+                        .unbounded_send(EditAgentOutputEvent::AmbiguousEditRange(ranges))
+                        .ok();
+                    continue;
+                }
             };
 
             // Compute edits in the background and apply them as they become
@@ -405,7 +419,7 @@ impl EditAgent {
         mut edit_events: T,
         cx: &mut AsyncApp,
     ) -> (
-        Task<Result<(T, Option<ResolvedOldText>)>>,
+        Task<Result<(T, Vec<ResolvedOldText>)>>,
         async_watch::Receiver<Option<Range<usize>>>,
     )
     where
@@ -425,21 +439,29 @@ impl EditAgent {
                 }
             }
 
-            let old_range = matcher.finish();
-            old_range_tx.send(old_range.clone())?;
-            if let Some(old_range) = old_range {
-                let line_indent =
-                    LineIndent::from_iter(matcher.query_lines().first().unwrap().chars());
-                Ok((
-                    edit_events,
-                    Some(ResolvedOldText {
-                        range: old_range,
-                        indent: line_indent,
-                    }),
-                ))
+            let matches = matcher.finish();
+
+            let old_range = if matches.len() == 1 {
+                matches.first()
             } else {
-                Ok((edit_events, None))
-            }
+                // No matches or multiple ambiguous matches
+                None
+            };
+            old_range_tx.send(old_range.cloned())?;
+
+            let indent = LineIndent::from_iter(
+                matcher
+                    .query_lines()
+                    .first()
+                    .unwrap_or(&String::new())
+                    .chars(),
+            );
+            let resolved_old_texts = matches
+                .into_iter()
+                .map(|range| ResolvedOldText { range, indent })
+                .collect::<Vec<_>>();
+
+            Ok((edit_events, resolved_old_texts))
         });
 
         (task, old_range_rx)
@@ -1322,6 +1344,76 @@ mod tests {
         EditAgent::new(model, project, action_log, Templates::new())
     }
 
+    #[gpui::test(iterations = 10)]
+    async fn test_non_unique_text_error(cx: &mut TestAppContext, mut rng: StdRng) {
+        let agent = init_test(cx).await;
+        let original_text = indoc! {"
+                function foo() {
+                    return 42;
+                }
+
+                function bar() {
+                    return 42;
+                }
+
+                function baz() {
+                    return 42;
+                }
+            "};
+        let buffer = cx.new(|cx| Buffer::local(original_text, cx));
+        let (apply, mut events) = agent.edit(
+            buffer.clone(),
+            String::new(),
+            &LanguageModelRequest::default(),
+            &mut cx.to_async(),
+        );
+        cx.run_until_parked();
+
+        // When <old_text> matches text in more than one place
+        simulate_llm_output(
+            &agent,
+            indoc! {"
+                <old_text>
+                return 42;
+                </old_text>
+                <new_text>
+                return 100;
+                </new_text>
+            "},
+            &mut rng,
+            cx,
+        );
+        apply.await.unwrap();
+
+        // Then the text should remain unchanged
+        let result_text = buffer.read_with(cx, |buffer, _| buffer.snapshot().text());
+        assert_eq!(
+            result_text,
+            indoc! {"
+                function foo() {
+                    return 42;
+                }
+
+                function bar() {
+                    return 42;
+                }
+
+                function baz() {
+                    return 42;
+                }
+            "},
+            "Text should remain unchanged when there are multiple matches"
+        );
+
+        // And AmbiguousEditRange even should be emitted
+        let events = drain_events(&mut events);
+        let ambiguous_ranges = vec![17..31, 52..66, 87..101];
+        assert!(
+            events.contains(&EditAgentOutputEvent::AmbiguousEditRange(ambiguous_ranges)),
+            "Should emit AmbiguousEditRange for non-unique text"
+        );
+    }
+
     fn drain_events(
         stream: &mut UnboundedReceiver<EditAgentOutputEvent>,
     ) -> Vec<EditAgentOutputEvent> {

crates/assistant_tools/src/edit_agent/streaming_fuzzy_matcher.rs 🔗

@@ -11,7 +11,7 @@ pub struct StreamingFuzzyMatcher {
     snapshot: TextBufferSnapshot,
     query_lines: Vec<String>,
     incomplete_line: String,
-    best_match: Option<Range<usize>>,
+    best_matches: Vec<Range<usize>>,
     matrix: SearchMatrix,
 }
 
@@ -22,7 +22,7 @@ impl StreamingFuzzyMatcher {
             snapshot,
             query_lines: Vec::new(),
             incomplete_line: String::new(),
-            best_match: None,
+            best_matches: Vec::new(),
             matrix: SearchMatrix::new(buffer_line_count + 1),
         }
     }
@@ -55,31 +55,41 @@ impl StreamingFuzzyMatcher {
 
             self.incomplete_line.replace_range(..last_pos + 1, "");
 
-            self.best_match = self.resolve_location_fuzzy();
-        }
+            self.best_matches = self.resolve_location_fuzzy();
 
-        self.best_match.clone()
+            if let Some(first_match) = self.best_matches.first() {
+                Some(first_match.clone())
+            } else {
+                None
+            }
+        } else {
+            if let Some(first_match) = self.best_matches.first() {
+                Some(first_match.clone())
+            } else {
+                None
+            }
+        }
     }
 
-    /// Finish processing and return the final best match.
+    /// Finish processing and return the final best match(es).
     ///
     /// This processes any remaining incomplete line before returning the final
     /// match result.
-    pub fn finish(&mut self) -> Option<Range<usize>> {
+    pub fn finish(&mut self) -> Vec<Range<usize>> {
         // Process any remaining incomplete line
         if !self.incomplete_line.is_empty() {
             self.query_lines.push(self.incomplete_line.clone());
-            self.best_match = self.resolve_location_fuzzy();
+            self.incomplete_line.clear();
+            self.best_matches = self.resolve_location_fuzzy();
         }
-
-        self.best_match.clone()
+        self.best_matches.clone()
     }
 
-    fn resolve_location_fuzzy(&mut self) -> Option<Range<usize>> {
+    fn resolve_location_fuzzy(&mut self) -> Vec<Range<usize>> {
         let new_query_line_count = self.query_lines.len();
         let old_query_line_count = self.matrix.rows.saturating_sub(1);
         if new_query_line_count == old_query_line_count {
-            return None;
+            return Vec::new();
         }
 
         self.matrix.resize_rows(new_query_line_count + 1);
@@ -132,53 +142,61 @@ impl StreamingFuzzyMatcher {
             }
         }
 
-        // Traceback to find the best match
+        // Find all matches with the best cost
         let buffer_line_count = self.snapshot.max_point().row as usize + 1;
-        let mut buffer_row_end = buffer_line_count as u32;
         let mut best_cost = u32::MAX;
+        let mut matches_with_best_cost = Vec::new();
+
         for col in 1..=buffer_line_count {
             let cost = self.matrix.get(new_query_line_count, col).cost;
             if cost < best_cost {
                 best_cost = cost;
-                buffer_row_end = col as u32;
+                matches_with_best_cost.clear();
+                matches_with_best_cost.push(col as u32);
+            } else if cost == best_cost {
+                matches_with_best_cost.push(col as u32);
             }
         }
 
-        let mut matched_lines = 0;
-        let mut query_row = new_query_line_count;
-        let mut buffer_row_start = buffer_row_end;
-        while query_row > 0 && buffer_row_start > 0 {
-            let current = self.matrix.get(query_row, buffer_row_start as usize);
-            match current.direction {
-                SearchDirection::Diagonal => {
-                    query_row -= 1;
-                    buffer_row_start -= 1;
-                    matched_lines += 1;
-                }
-                SearchDirection::Up => {
-                    query_row -= 1;
-                }
-                SearchDirection::Left => {
-                    buffer_row_start -= 1;
+        // Find ranges for the matches
+        let mut valid_matches = Vec::new();
+        for &buffer_row_end in &matches_with_best_cost {
+            let mut matched_lines = 0;
+            let mut query_row = new_query_line_count;
+            let mut buffer_row_start = buffer_row_end;
+            while query_row > 0 && buffer_row_start > 0 {
+                let current = self.matrix.get(query_row, buffer_row_start as usize);
+                match current.direction {
+                    SearchDirection::Diagonal => {
+                        query_row -= 1;
+                        buffer_row_start -= 1;
+                        matched_lines += 1;
+                    }
+                    SearchDirection::Up => {
+                        query_row -= 1;
+                    }
+                    SearchDirection::Left => {
+                        buffer_row_start -= 1;
+                    }
                 }
             }
-        }
 
-        let matched_buffer_row_count = buffer_row_end - buffer_row_start;
-        let matched_ratio = matched_lines as f32
-            / (matched_buffer_row_count as f32).max(new_query_line_count as f32);
-        if matched_ratio >= 0.8 {
-            let buffer_start_ix = self
-                .snapshot
-                .point_to_offset(Point::new(buffer_row_start, 0));
-            let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
-                buffer_row_end - 1,
-                self.snapshot.line_len(buffer_row_end - 1),
-            ));
-            Some(buffer_start_ix..buffer_end_ix)
-        } else {
-            None
+            let matched_buffer_row_count = buffer_row_end - buffer_row_start;
+            let matched_ratio = matched_lines as f32
+                / (matched_buffer_row_count as f32).max(new_query_line_count as f32);
+            if matched_ratio >= 0.8 {
+                let buffer_start_ix = self
+                    .snapshot
+                    .point_to_offset(Point::new(buffer_row_start, 0));
+                let buffer_end_ix = self.snapshot.point_to_offset(Point::new(
+                    buffer_row_end - 1,
+                    self.snapshot.line_len(buffer_row_end - 1),
+                ));
+                valid_matches.push((buffer_row_start, buffer_start_ix..buffer_end_ix));
+            }
         }
+
+        valid_matches.into_iter().map(|(_, range)| range).collect()
     }
 }
 
@@ -638,28 +656,35 @@ mod tests {
             matcher.push(chunk);
         }
 
-        let result = matcher.finish();
+        let actual_ranges = matcher.finish();
 
         // If no expected ranges, we expect no match
         if expected_ranges.is_empty() {
-            assert_eq!(
-                result, None,
+            assert!(
+                actual_ranges.is_empty(),
                 "Expected no match for query: {:?}, but found: {:?}",
-                query, result
+                query,
+                actual_ranges
             );
         } else {
-            let mut actual_ranges = Vec::new();
-            if let Some(range) = result {
-                actual_ranges.push(range);
-            }
-
             let text_with_actual_range = generate_marked_text(&text, &actual_ranges, false);
             pretty_assertions::assert_eq!(
                 text_with_actual_range,
                 text_with_expected_range,
-                "Query: {:?}, Chunks: {:?}",
+                indoc! {"
+                    Query: {:?}
+                    Chunks: {:?}
+                    Expected marked text: {}
+                    Actual marked text: {}
+                    Expected ranges: {:?}
+                    Actual ranges: {:?}"
+                },
                 query,
-                chunks
+                chunks,
+                text_with_expected_range,
+                text_with_actual_range,
+                expected_ranges,
+                actual_ranges
             );
         }
     }
@@ -687,8 +712,11 @@ mod tests {
 
     fn finish(mut finder: StreamingFuzzyMatcher) -> Option<String> {
         let snapshot = finder.snapshot.clone();
-        finder
-            .finish()
-            .map(|range| snapshot.text_for_range(range).collect::<String>())
+        let matches = finder.finish();
+        if let Some(range) = matches.first() {
+            Some(snapshot.text_for_range(range.clone()).collect::<String>())
+        } else {
+            None
+        }
     }
 }

crates/assistant_tools/src/edit_file_tool.rs 🔗

@@ -239,6 +239,7 @@ impl Tool for EditFileTool {
             };
 
             let mut hallucinated_old_text = false;
+            let mut ambiguous_ranges = Vec::new();
             while let Some(event) = events.next().await {
                 match event {
                     EditAgentOutputEvent::Edited => {
@@ -247,6 +248,7 @@ impl Tool for EditFileTool {
                         }
                     }
                     EditAgentOutputEvent::UnresolvedEditRange => hallucinated_old_text = true,
+                    EditAgentOutputEvent::AmbiguousEditRange(ranges) => ambiguous_ranges = ranges,
                     EditAgentOutputEvent::ResolvingEditRange(range) => {
                         if let Some(card) = card_clone.as_ref() {
                             card.update(cx, |card, cx| card.reveal_range(range, cx))?;
@@ -329,6 +331,17 @@ impl Tool for EditFileTool {
                         I can perform the requested edits.
                     "}
                 );
+                anyhow::ensure!(
+                    ambiguous_ranges.is_empty(),
+                    // TODO: Include ambiguous_ranges, converted to line numbers.
+                    //       This would work best if we add `line_hint` parameter
+                    //       to edit_file_tool
+                    formatdoc! {"
+                        <old_text> matches more than one position in the file. Read the
+                        relevant sections of {input_path} again and extend <old_text> so
+                        that I can perform the requested edits.
+                    "}
+                );
                 Ok(ToolResultOutput {
                     content: ToolResultContent::Text("No edits were made.".into()),
                     output: serde_json::to_value(output).ok(),