diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index c6caa6a1b7b4076cf739c1ac198656b9fba431a6..da1de042623167d17f078c1e85b461fb0ecc8c24 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -31,7 +31,7 @@ pub struct EditPredictionExcerptOptions { pub include_parent_signatures: bool, } -#[derive(Clone)] +#[derive(Debug, Clone)] pub struct EditPredictionExcerpt { pub range: Range, pub parent_signature_ranges: Vec>, diff --git a/crates/edit_prediction_context/src/scored_declaration.rs b/crates/edit_prediction_context/src/scored_declaration.rs index 56629a0505a6bae1e32aec158ba4a2985ef803d7..14198e453c8c5a09e05a46cb2db8d7c230f60cd3 100644 --- a/crates/edit_prediction_context/src/scored_declaration.rs +++ b/crates/edit_prediction_context/src/scored_declaration.rs @@ -53,7 +53,7 @@ fn scored_snippets( index: Entity, excerpt: &EditPredictionExcerpt, excerpt_text: &EditPredictionExcerptText, - references: Vec, + identifier_to_references: HashMap>, cursor_offset: usize, current_buffer: &BufferSnapshot, cx: &App, @@ -74,14 +74,6 @@ fn scored_snippets( .collect::(), ); - let mut identifier_to_references: HashMap> = HashMap::new(); - for reference in references { - identifier_to_references - .entry(reference.identifier.clone()) - .or_insert_with(Vec::new) - .push(reference); - } - identifier_to_references .into_iter() .flat_map(|(identifier, references)| { @@ -326,3 +318,175 @@ impl ScoreInputs { } } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use gpui::{TestAppContext, prelude::*}; + use indoc::indoc; + use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use text::ToOffset; + use util::path; + + use crate::{ + EditPredictionExcerptOptions, references_in_excerpt, tree_sitter_index::TreeSitterIndex, + }; + + #[gpui::test] + async fn test_call_site(cx: &mut TestAppContext) { + let (project, index, _rust_lang_id) = init_test(cx).await; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + // first process_data call site + let cursor_point = language::Point::new(8, 21); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + let excerpt = EditPredictionExcerpt::select_from_buffer( + cursor_point, + &buffer_snapshot, + &EditPredictionExcerptOptions { + max_bytes: 40, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }, + ) + .unwrap(); + let excerpt_text = excerpt.text(&buffer_snapshot); + let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer_snapshot); + let cursor_offset = cursor_point.to_offset(&buffer_snapshot); + + let snippets = cx.update(|cx| { + scored_snippets( + index, + &excerpt, + &excerpt_text, + references, + cursor_offset, + &buffer_snapshot, + cx, + ) + }); + + assert_eq!(snippets.len(), 1); + assert_eq!(snippets[0].identifier.name.as_ref(), "process_data"); + drop(buffer); + } + + async fn init_test( + cx: &mut TestAppContext, + ) -> (Entity, Entity, LanguageId) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "a.rs": indoc! {r#" + fn main() { + let x = 1; + let y = 2; + let z = add(x, y); + println!("Result: {}", z); + } + + fn add(a: i32, b: i32) -> i32 { + a + b + } + "#}, + "b.rs": indoc! {" + pub struct Config { + pub name: String, + pub value: i32, + } + + impl Config { + pub fn new(name: String, value: i32) -> Self { + Config { name, value } + } + } + "}, + "c.rs": indoc! {r#" + use std::collections::HashMap; + + fn main() { + let args: Vec = std::env::args().collect(); + let data: Vec = args[1..] + .iter() + .filter_map(|s| s.parse().ok()) + .collect(); + let result = process_data(data); + println!("{:?}", result); + } + + fn process_data(data: Vec) -> HashMap { + let mut counts = HashMap::new(); + for value in data { + *counts.entry(value).or_insert(0) += 1; + } + counts + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_process_data() { + let data = vec![1, 2, 2, 3]; + let result = process_data(data); + assert_eq!(result.get(&2), Some(&2)); + } + } + "#} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let lang = rust_lang(); + let lang_id = lang.id(); + language_registry.add(Arc::new(lang)); + + let index = cx.new(|cx| TreeSitterIndex::new(&project, cx)); + cx.run_until_parked(); + + (project, index, lang_id) + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) + .unwrap() + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +}