Fix infinite loop in assemble_excerpts (#44195)

Max Brunsfeld created

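Also, expand the number of identifiers fetched by searching the lines adjacent to the cursor (by default, three lines above and below) rather than only the cursor's own line.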

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/edit_prediction.rs                       |   8 
crates/edit_prediction_context/src/assemble_excerpts.rs             | 165 
crates/edit_prediction_context/src/edit_prediction_context.rs       |  33 
crates/edit_prediction_context/src/edit_prediction_context_tests.rs | 174 
4 files changed, 204 insertions(+), 176 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -480,16 +480,16 @@ impl EditPredictionStore {
             shown_predictions: Default::default(),
         };
 
-        this.enable_or_disable_context_retrieval(cx);
+        this.configure_context_retrieval(cx);
         let weak_this = cx.weak_entity();
         cx.on_flags_ready(move |_, cx| {
             weak_this
-                .update(cx, |this, cx| this.enable_or_disable_context_retrieval(cx))
+                .update(cx, |this, cx| this.configure_context_retrieval(cx))
                 .ok();
         })
         .detach();
         cx.observe_global::<SettingsStore>(|this, cx| {
-            this.enable_or_disable_context_retrieval(cx);
+            this.configure_context_retrieval(cx);
         })
         .detach();
 
@@ -1770,7 +1770,7 @@ impl EditPredictionStore {
         cx.notify();
     }
 
-    fn enable_or_disable_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
+    fn configure_context_retrieval(&mut self, cx: &mut Context<'_, EditPredictionStore>) {
         self.use_context = cx.has_flag::<Zeta2FeatureFlag>()
             && all_language_settings(None, cx).edit_predictions.use_context;
     }

crates/edit_prediction_context/src/assemble_excerpts.rs 🔗

@@ -61,8 +61,8 @@ pub fn assemble_excerpts(
                                 buffer,
                                 &mut outline_ranges,
                             );
-                            child_outline_ix += 1;
                         }
+                        child_outline_ix += 1;
                     }
                 }
             }
@@ -159,166 +159,3 @@ pub fn merge_ranges(ranges: &mut Vec<Range<Point>>) {
         }
     }
 }
-
-#[cfg(test)]
-mod tests {
-    use super::*;
-    use gpui::{TestAppContext, prelude::*};
-    use indoc::indoc;
-    use language::{Buffer, Language, LanguageConfig, LanguageMatcher, OffsetRangeExt};
-    use pretty_assertions::assert_eq;
-    use std::{fmt::Write as _, sync::Arc};
-    use util::test::marked_text_ranges;
-
-    #[gpui::test]
-    fn test_rust(cx: &mut TestAppContext) {
-        let table = [
-            (
-                indoc! {r#"
-                    struct User {
-                        first_name: String,
-                        «last_name»: String,
-                        age: u32,
-                        email: String,
-                        create_at: Instant,
-                    }
-
-                    impl User {
-                        pub fn first_name(&self) -> String {
-                            self.first_name.clone()
-                        }
-
-                        pub fn full_name(&self) -> String {
-                    «        format!("{} {}", self.first_name, self.last_name)
-                    »    }
-                    }
-                "#},
-                indoc! {r#"
-                    struct User {
-                        first_name: String,
-                        last_name: String,
-                    …
-                    }
-
-                    impl User {
-                    …
-                        pub fn full_name(&self) -> String {
-                            format!("{} {}", self.first_name, self.last_name)
-                        }
-                    }
-                "#},
-            ),
-            (
-                indoc! {r#"
-                    struct «User» {
-                        first_name: String,
-                        last_name: String,
-                        age: u32,
-                    }
-
-                    impl User {
-                        // methods
-                    }
-                    "#
-                },
-                indoc! {r#"
-                    struct User {
-                        first_name: String,
-                        last_name: String,
-                        age: u32,
-                    }
-                    …
-                "#},
-            ),
-            (
-                indoc! {r#"
-                    trait «FooProvider» {
-                        const NAME: &'static str;
-
-                        fn provide_foo(&self, id: usize) -> Foo;
-
-                        fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
-                             ids.iter()
-                                .map(|id| self.provide_foo(*id))
-                                .collect()
-                        }
-
-                        fn sync(&self);
-                    }
-                    "#
-                },
-                indoc! {r#"
-                    trait FooProvider {
-                        const NAME: &'static str;
-
-                        fn provide_foo(&self, id: usize) -> Foo;
-
-                        fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
-                    …
-                        }
-
-                        fn sync(&self);
-                    }
-                "#},
-            ),
-        ];
-
-        for (input, expected_output) in table {
-            let (input, ranges) = marked_text_ranges(&input, false);
-            let buffer =
-                cx.new(|cx| Buffer::local(input, cx).with_language(Arc::new(rust_lang()), cx));
-            buffer.read_with(cx, |buffer, _cx| {
-                let ranges: Vec<Range<Point>> = ranges
-                    .into_iter()
-                    .map(|range| range.to_point(&buffer))
-                    .collect();
-
-                let excerpts = assemble_excerpts(&buffer.snapshot(), ranges);
-
-                let output = format_excerpts(buffer, &excerpts);
-                assert_eq!(output, expected_output);
-            });
-        }
-    }
-
-    fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
-        let mut output = String::new();
-        let file_line_count = buffer.max_point().row;
-        let mut current_row = 0;
-        for excerpt in excerpts {
-            if excerpt.text.is_empty() {
-                continue;
-            }
-            if current_row < excerpt.point_range.start.row {
-                writeln!(&mut output, "…").unwrap();
-            }
-            current_row = excerpt.point_range.start.row;
-
-            for line in excerpt.text.to_string().lines() {
-                output.push_str(line);
-                output.push('\n');
-                current_row += 1;
-            }
-        }
-        if current_row < file_line_count {
-            writeln!(&mut output, "…").unwrap();
-        }
-        output
-    }
-
-    fn rust_lang() -> Language {
-        Language::new(
-            LanguageConfig {
-                name: "Rust".into(),
-                matcher: LanguageMatcher {
-                    path_suffixes: vec!["rs".to_string()],
-                    ..Default::default()
-                },
-                ..Default::default()
-            },
-            Some(language::tree_sitter_rust::LANGUAGE.into()),
-        )
-        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
-        .unwrap()
-    }
-}

crates/edit_prediction_context/src/edit_prediction_context.rs 🔗

@@ -25,11 +25,14 @@ mod fake_definition_lsp;
 pub use cloud_llm_client::predict_edits_v3::Line;
 pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 
+const IDENTIFIER_LINE_COUNT: u32 = 3;
+
 pub struct RelatedExcerptStore {
     project: WeakEntity<Project>,
     related_files: Vec<RelatedFile>,
     cache: HashMap<Identifier, Arc<CacheEntry>>,
     update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
+    identifier_line_count: u32,
 }
 
 pub enum RelatedExcerptStoreEvent {
@@ -178,9 +181,14 @@ impl RelatedExcerptStore {
             update_tx,
             related_files: Vec::new(),
             cache: Default::default(),
+            identifier_line_count: IDENTIFIER_LINE_COUNT,
         }
     }
 
+    pub fn set_identifier_line_count(&mut self, count: u32) {
+        self.identifier_line_count = count;
+    }
+
     pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
         self.update_tx.unbounded_send((buffer, position)).ok();
     }
@@ -195,8 +203,12 @@ impl RelatedExcerptStore {
         position: Anchor,
         cx: &mut AsyncApp,
     ) -> Result<()> {
-        let (project, snapshot) = this.read_with(cx, |this, cx| {
-            (this.project.upgrade(), buffer.read(cx).snapshot())
+        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
+            (
+                this.project.upgrade(),
+                buffer.read(cx).snapshot(),
+                this.identifier_line_count,
+            )
         })?;
         let Some(project) = project else {
             return Ok(());
@@ -212,7 +224,9 @@ impl RelatedExcerptStore {
         })?;
 
         let identifiers = cx
-            .background_spawn(async move { identifiers_for_position(&snapshot, position) })
+            .background_spawn(async move {
+                identifiers_for_position(&snapshot, position, identifier_line_count)
+            })
             .await;
 
         let async_cx = cx.clone();
@@ -393,14 +407,21 @@ fn process_definition(
 
 /// Gets all of the identifiers that are present in the given line, and its containing
 /// outline items.
-fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec<Identifier> {
+fn identifiers_for_position(
+    buffer: &BufferSnapshot,
+    position: Anchor,
+    identifier_line_count: u32,
+) -> Vec<Identifier> {
     let offset = position.to_offset(buffer);
     let point = buffer.offset_to_point(offset);
 
-    let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point());
+    // Search for identifiers on lines adjacent to the cursor.
+    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
+    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
+    let line_range = start..end;
     let mut ranges = vec![line_range.to_offset(&buffer)];
 
-    // Include the range of the outline item itself, but not its body.
+    // Search for identifiers mentioned in headers/signatures of containing outline items.
     let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
     for item in outline_items {
         if let Some(body_range) = item.body_range(&buffer) {

crates/edit_prediction_context/src/edit_prediction_context_tests.rs 🔗

@@ -7,8 +7,8 @@ use lsp::FakeLanguageServer;
 use project::{FakeFs, LocationLink, Project};
 use serde_json::json;
 use settings::SettingsStore;
-use std::sync::Arc;
-use util::path;
+use std::{fmt::Write as _, sync::Arc};
+use util::{path, test::marked_text_ranges};
 
 #[gpui::test]
 async fn test_edit_prediction_context(cx: &mut TestAppContext) {
@@ -37,6 +37,7 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
             buffer.anchor_before(offset)
         };
 
+        store.set_identifier_line_count(0);
         store.refresh(buffer.clone(), position, cx);
     });
 
@@ -85,6 +86,150 @@ async fn test_edit_prediction_context(cx: &mut TestAppContext) {
     });
 }
 
+#[gpui::test]
+fn test_assemble_excerpts(cx: &mut TestAppContext) {
+    let table = [
+        (
+            indoc! {r#"
+                struct User {
+                    first_name: String,
+                    «last_name»: String,
+                    age: u32,
+                    email: String,
+                    create_at: Instant,
+                }
+
+                impl User {
+                    pub fn first_name(&self) -> String {
+                        self.first_name.clone()
+                    }
+
+                    pub fn full_name(&self) -> String {
+                «        format!("{} {}", self.first_name, self.last_name)
+                »    }
+                }
+            "#},
+            indoc! {r#"
+                struct User {
+                    first_name: String,
+                    last_name: String,
+                …
+                }
+
+                impl User {
+                …
+                    pub fn full_name(&self) -> String {
+                        format!("{} {}", self.first_name, self.last_name)
+                    }
+                }
+            "#},
+        ),
+        (
+            indoc! {r#"
+                struct «User» {
+                    first_name: String,
+                    last_name: String,
+                    age: u32,
+                }
+
+                impl User {
+                    // methods
+                }
+            "#},
+            indoc! {r#"
+                struct User {
+                    first_name: String,
+                    last_name: String,
+                    age: u32,
+                }
+                …
+            "#},
+        ),
+        (
+            indoc! {r#"
+                trait «FooProvider» {
+                    const NAME: &'static str;
+
+                    fn provide_foo(&self, id: usize) -> Foo;
+
+                    fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
+                            ids.iter()
+                            .map(|id| self.provide_foo(*id))
+                            .collect()
+                    }
+
+                    fn sync(&self);
+                }
+                "#
+            },
+            indoc! {r#"
+                trait FooProvider {
+                    const NAME: &'static str;
+
+                    fn provide_foo(&self, id: usize) -> Foo;
+
+                    fn provide_foo_batched(&self, ids: &[usize]) -> Vec<Foo> {
+                …
+                    }
+
+                    fn sync(&self);
+                }
+            "#},
+        ),
+        (
+            indoc! {r#"
+                trait «Something» {
+                    fn method1(&self, id: usize) -> Foo;
+
+                    fn method2(&self, ids: &[usize]) -> Vec<Foo> {
+                            struct Helper1 {
+                            field1: usize,
+                            }
+
+                            struct Helper2 {
+                            field2: usize,
+                            }
+
+                            struct Helper3 {
+                            filed2: usize,
+                        }
+                    }
+
+                    fn sync(&self);
+                }
+                "#
+            },
+            indoc! {r#"
+                trait Something {
+                    fn method1(&self, id: usize) -> Foo;
+
+                    fn method2(&self, ids: &[usize]) -> Vec<Foo> {
+                …
+                    }
+
+                    fn sync(&self);
+                }
+            "#},
+        ),
+    ];
+
+    for (input, expected_output) in table {
+        let (input, ranges) = marked_text_ranges(&input, false);
+        let buffer = cx.new(|cx| Buffer::local(input, cx).with_language(rust_lang(), cx));
+        buffer.read_with(cx, |buffer, _cx| {
+            let ranges: Vec<Range<Point>> = ranges
+                .into_iter()
+                .map(|range| range.to_point(&buffer))
+                .collect();
+
+            let excerpts = assemble_excerpts(&buffer.snapshot(), ranges);
+
+            let output = format_excerpts(buffer, &excerpts);
+            assert_eq!(output, expected_output);
+        });
+    }
+}
+
 #[gpui::test]
 async fn test_fake_definition_lsp(cx: &mut TestAppContext) {
     init_test(cx);
@@ -339,6 +484,31 @@ fn assert_definitions(definitions: &[LocationLink], first_lines: &[&str], cx: &m
     assert_eq!(actual_first_lines, first_lines);
 }
 
+fn format_excerpts(buffer: &Buffer, excerpts: &[RelatedExcerpt]) -> String {
+    let mut output = String::new();
+    let file_line_count = buffer.max_point().row;
+    let mut current_row = 0;
+    for excerpt in excerpts {
+        if excerpt.text.is_empty() {
+            continue;
+        }
+        if current_row < excerpt.point_range.start.row {
+            writeln!(&mut output, "…").unwrap();
+        }
+        current_row = excerpt.point_range.start.row;
+
+        for line in excerpt.text.to_string().lines() {
+            output.push_str(line);
+            output.push('\n');
+            current_row += 1;
+        }
+    }
+    if current_row < file_line_count {
+        writeln!(&mut output, "…").unwrap();
+    }
+    output
+}
+
 pub(crate) fn rust_lang() -> Arc<Language> {
     Arc::new(
         Language::new(