Simple call site snippet test

Agus Zubiaga created

Change summary

crates/edit_prediction_context/src/excerpt.rs            |   2 
crates/edit_prediction_context/src/scored_declaration.rs | 182 +++++++++
2 files changed, 174 insertions(+), 10 deletions(-)

Detailed changes

crates/edit_prediction_context/src/excerpt.rs 🔗

@@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions {
     pub include_parent_signatures: bool,
 }
 
-#[derive(Clone)]
+#[derive(Debug, Clone)]
 pub struct EditPredictionExcerpt {
     pub range: Range<usize>,
     pub parent_signature_ranges: Vec<Range<usize>>,

crates/edit_prediction_context/src/scored_declaration.rs 🔗

@@ -53,7 +53,7 @@ fn scored_snippets(
     index: Entity<TreeSitterIndex>,
     excerpt: &EditPredictionExcerpt,
     excerpt_text: &EditPredictionExcerptText,
-    references: Vec<Reference>,
+    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
     cursor_offset: usize,
     current_buffer: &BufferSnapshot,
     cx: &App,
@@ -74,14 +74,6 @@ fn scored_snippets(
             .collect::<String>(),
     );
 
-    let mut identifier_to_references: HashMap<Identifier, Vec<Reference>> = HashMap::new();
-    for reference in references {
-        identifier_to_references
-            .entry(reference.identifier.clone())
-            .or_insert_with(Vec::new)
-            .push(reference);
-    }
-
     identifier_to_references
         .into_iter()
         .flat_map(|(identifier, references)| {
@@ -326,3 +318,175 @@ impl ScoreInputs {
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::Arc;
+
+    use gpui::{TestAppContext, prelude::*};
+    use indoc::indoc;
+    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
+    use project::{FakeFs, Project};
+    use serde_json::json;
+    use settings::SettingsStore;
+    use text::ToOffset;
+    use util::path;
+
+    use crate::{
+        EditPredictionExcerptOptions, references_in_excerpt, tree_sitter_index::TreeSitterIndex,
+    };
+
+    #[gpui::test]
+    async fn test_call_site(cx: &mut TestAppContext) {
+        let (project, index, _rust_lang_id) = init_test(cx).await;
+
+        let buffer = project
+            .update(cx, |project, cx| {
+                let project_path = project.find_project_path("c.rs", cx).unwrap();
+                project.open_buffer(project_path, cx)
+            })
+            .await
+            .unwrap();
+
+        cx.run_until_parked();
+
+        // first process_data call site
+        let cursor_point = language::Point::new(8, 21);
+        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
+        let excerpt = EditPredictionExcerpt::select_from_buffer(
+            cursor_point,
+            &buffer_snapshot,
+            &EditPredictionExcerptOptions {
+                max_bytes: 40,
+                min_bytes: 10,
+                target_before_cursor_over_total_bytes: 0.5,
+                include_parent_signatures: false,
+            },
+        )
+        .unwrap();
+        let excerpt_text = excerpt.text(&buffer_snapshot);
+        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer_snapshot);
+        let cursor_offset = cursor_point.to_offset(&buffer_snapshot);
+
+        let snippets = cx.update(|cx| {
+            scored_snippets(
+                index,
+                &excerpt,
+                &excerpt_text,
+                references,
+                cursor_offset,
+                &buffer_snapshot,
+                cx,
+            )
+        });
+
+        assert_eq!(snippets.len(), 1);
+        assert_eq!(snippets[0].identifier.name.as_ref(), "process_data");
+        drop(buffer);
+    }
+
+    async fn init_test(
+        cx: &mut TestAppContext,
+    ) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
+        cx.update(|cx| {
+            let settings_store = SettingsStore::test(cx);
+            cx.set_global(settings_store);
+            language::init(cx);
+            Project::init_settings(cx);
+        });
+
+        let fs = FakeFs::new(cx.executor());
+        fs.insert_tree(
+            path!("/root"),
+            json!({
+                "a.rs": indoc! {r#"
+                    fn main() {
+                        let x = 1;
+                        let y = 2;
+                        let z = add(x, y);
+                        println!("Result: {}", z);
+                    }
+
+                    fn add(a: i32, b: i32) -> i32 {
+                        a + b
+                    }
+                "#},
+                "b.rs": indoc! {"
+                    pub struct Config {
+                        pub name: String,
+                        pub value: i32,
+                    }
+
+                    impl Config {
+                        pub fn new(name: String, value: i32) -> Self {
+                            Config { name, value }
+                        }
+                    }
+                "},
+                "c.rs": indoc! {r#"
+                    use std::collections::HashMap;
+
+                    fn main() {
+                        let args: Vec<String> = std::env::args().collect();
+                        let data: Vec<i32> = args[1..]
+                            .iter()
+                            .filter_map(|s| s.parse().ok())
+                            .collect();
+                        let result = process_data(data);
+                        println!("{:?}", result);
+                    }
+
+                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
+                        let mut counts = HashMap::new();
+                        for value in data {
+                            *counts.entry(value).or_insert(0) += 1;
+                        }
+                        counts
+                    }
+
+                    #[cfg(test)]
+                    mod tests {
+                        use super::*;
+
+                        #[test]
+                        fn test_process_data() {
+                            let data = vec![1, 2, 2, 3];
+                            let result = process_data(data);
+                            assert_eq!(result.get(&2), Some(&2));
+                        }
+                    }
+                "#}
+            }),
+        )
+        .await;
+        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
+        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
+        let lang = rust_lang();
+        let lang_id = lang.id();
+        language_registry.add(Arc::new(lang));
+
+        let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
+        cx.run_until_parked();
+
+        (project, index, lang_id)
+    }
+
+    fn rust_lang() -> Language {
+        Language::new(
+            LanguageConfig {
+                name: "Rust".into(),
+                matcher: LanguageMatcher {
+                    path_suffixes: vec!["rs".to_string()],
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            Some(tree_sitter_rust::LANGUAGE.into()),
+        )
+        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
+        .unwrap()
+        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
+        .unwrap()
+    }
+}