diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml
index 29755a3b0eecdb3eacf5f8d4b0b0dba69135d4b5..912775d5045a74fd26eaf445e20afef336f40f54 100644
--- a/crates/edit_prediction_context/Cargo.toml
+++ b/crates/edit_prediction_context/Cargo.toml
@@ -19,7 +19,9 @@ path = "examples/zeta_context.rs"
 anyhow.workspace = true
 arrayvec.workspace = true
 collections.workspace = true
+futures.workspace = true
 gpui.workspace = true
+itertools.workspace = true
 language.workspace = true
 log.workspace = true
 project.workspace = true
@@ -31,7 +33,6 @@ text.workspace = true
 tree-sitter.workspace = true
 util.workspace = true
 workspace-hack.workspace = true
-itertools.workspace = true
 
 [dev-dependencies]
 clap.workspace = true
diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs
index 9eee5a8273810441a578cf48843cf6c9be9a70f6..ad944359d54a2d4bf2ddb02df619e49a2d9fd7a8 100644
--- a/crates/edit_prediction_context/src/declaration.rs
+++ b/crates/edit_prediction_context/src/declaration.rs
@@ -1,10 +1,9 @@
-use gpui::{App, WeakEntity};
-use language::{Buffer, BufferSnapshot, LanguageId};
+use language::LanguageId;
 use project::ProjectEntryId;
 use std::borrow::Cow;
-use std::ops::{Deref, Range};
+use std::ops::Range;
 use std::sync::Arc;
-use text::{Anchor, Bias, OffsetRangeExt, ToOffset};
+use text::{Bias, BufferId, Rope};
 
 use crate::outline::OutlineDeclaration;
 
@@ -25,7 +24,9 @@ pub enum Declaration {
         declaration: FileDeclaration,
     },
     Buffer {
-        buffer: WeakEntity<Buffer>,
+        project_entry_id: ProjectEntryId,
+        buffer_id: BufferId,
+        rope: Rope,
         declaration: BufferDeclaration,
     },
 }
@@ -40,88 +41,79 @@ impl Declaration {
         }
     }
 
-    pub fn project_entry_id(&self, cx: &App) -> Option<ProjectEntryId> {
+    pub fn project_entry_id(&self) -> Option<ProjectEntryId> {
         match self {
             Declaration::File {
                 project_entry_id, ..
             } => Some(*project_entry_id),
-            Declaration::Buffer { buffer, .. } => buffer
-                .read_with(cx, |buffer, _cx| {
-                    project::File::from_dyn(buffer.file())
-                        .and_then(|file| file.project_entry_id(cx))
-                })
-                .ok()
-                .flatten(),
+            Declaration::Buffer {
+                project_entry_id, ..
+            } => Some(*project_entry_id),
         }
     }
 
-    pub fn item_text(&self, cx: &App) -> (Cow<'_, str>, bool) {
+    pub fn item_text(&self) -> (Cow<'_, str>, bool) {
         match self {
             Declaration::File { declaration, .. } => (
                 declaration.text.as_ref().into(),
                 declaration.text_is_truncated,
             ),
             Declaration::Buffer {
-                buffer,
-                declaration,
-            } => buffer
-                .read_with(cx, |buffer, _cx| {
-                    let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate(
-                        &declaration.item_range,
-                        ITEM_TEXT_TRUNCATION_LENGTH,
-                        buffer.deref(),
-                    );
-                    (
-                        buffer.text_for_range(range).collect::<Cow<str>>(),
-                        is_truncated,
-                    )
-                })
-                .unwrap_or_default(),
+                rope, declaration, ..
+            } => {
+                let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate(
+                    &declaration.item_range,
+                    ITEM_TEXT_TRUNCATION_LENGTH,
+                    rope,
+                );
+                (
+                    rope.chunks_in_range(range).collect::<Cow<str>>(),
+                    is_truncated,
+                )
+            }
         }
     }
 
-    pub fn signature_text(&self, cx: &App) -> (Cow<'_, str>, bool) {
+    pub fn signature_text(&self) -> (Cow<'_, str>, bool) {
         match self {
             Declaration::File { declaration, .. } => (
                 declaration.text[declaration.signature_range_in_text.clone()].into(),
                 declaration.signature_is_truncated,
             ),
             Declaration::Buffer {
-                buffer,
-                declaration,
-            } => buffer
-                .read_with(cx, |buffer, _cx| {
-                    let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate(
-                        &declaration.signature_range,
-                        ITEM_TEXT_TRUNCATION_LENGTH,
-                        buffer.deref(),
-                    );
-                    (
-                        buffer.text_for_range(range).collect::<Cow<str>>(),
-                        is_truncated,
-                    )
-                })
-                .unwrap_or_default(),
+                rope, declaration, ..
+            } => {
+                let (range, is_truncated) = expand_range_to_line_boundaries_and_truncate(
+                    &declaration.signature_range,
+                    ITEM_TEXT_TRUNCATION_LENGTH,
+                    rope,
+                );
+                (
+                    rope.chunks_in_range(range).collect::<Cow<str>>(),
+                    is_truncated,
+                )
+            }
         }
     }
 }
 
-fn expand_range_to_line_boundaries_and_truncate<T: ToOffset>(
-    range: &Range<T>,
+fn expand_range_to_line_boundaries_and_truncate(
+    range: &Range<usize>,
     limit: usize,
-    buffer: &text::BufferSnapshot,
+    rope: &Rope,
 ) -> (Range<usize>, bool) {
-    let mut point_range = range.to_point(buffer);
+    let mut point_range = rope.offset_to_point(range.start)..rope.offset_to_point(range.end);
     point_range.start.column = 0;
     point_range.end.row += 1;
     point_range.end.column = 0;
 
-    let mut item_range = point_range.to_offset(buffer);
+    let mut item_range =
+        rope.point_to_offset(point_range.start)..rope.point_to_offset(point_range.end);
     let is_truncated = item_range.len() > limit;
     if is_truncated {
        item_range.end = item_range.start + limit;
     }
-    item_range.end = buffer.clip_offset(item_range.end, Bias::Left);
+    item_range.end = rope.clip_offset(item_range.end, Bias::Left);
 
     (item_range, is_truncated)
 }
@@ -142,14 +134,11 @@ pub struct FileDeclaration {
 }
 
 impl FileDeclaration {
-    pub fn from_outline(
-        declaration: OutlineDeclaration,
-        snapshot: &BufferSnapshot,
-    ) -> FileDeclaration {
+    pub fn from_outline(declaration: OutlineDeclaration, rope: &Rope) -> FileDeclaration {
         let (item_range_in_file, text_is_truncated) = expand_range_to_line_boundaries_and_truncate(
             &declaration.item_range,
             ITEM_TEXT_TRUNCATION_LENGTH,
-            snapshot,
+            rope,
         );
 
         // TODO: consider logging if unexpected
@@ -171,8 +160,8 @@ impl FileDeclaration {
             identifier: declaration.identifier,
             signature_range_in_text: signature_start..signature_end,
             signature_is_truncated,
-            text: snapshot
-                .text_for_range(item_range_in_file.clone())
+            text: rope
+                .chunks_in_range(item_range_in_file.clone())
                 .collect::<String>()
                 .into(),
             text_is_truncated,
@@ -185,21 +174,19 @@
 pub struct BufferDeclaration {
     pub parent: Option<DeclarationId>,
     pub identifier: Identifier,
-    pub item_range: Range<Anchor>,
-    pub signature_range: Range<Anchor>,
+    pub item_range: Range<usize>,
+    pub signature_range: Range<usize>,
 }
 
 impl BufferDeclaration {
-    pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self {
+    pub fn from_outline(declaration: OutlineDeclaration) -> Self {
         // use of anchor_before is a guess that the proper behavior is to expand to include
         // insertions immediately before the declaration, but not for insertions immediately after
         Self {
             parent: None,
             identifier: declaration.identifier,
-            item_range: snapshot.anchor_before(declaration.item_range.start)
-                ..snapshot.anchor_before(declaration.item_range.end),
-            signature_range: snapshot.anchor_before(declaration.signature_range.start)
-                ..snapshot.anchor_before(declaration.signature_range.end),
+            item_range: declaration.item_range,
+            signature_range: declaration.signature_range,
         }
     }
 }
diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs
index df6f0f967580d0a3c819eaf2c134eb087a30e7e0..af26ea4ca46e83d82776e03a117e8f51855304ed 100644
--- a/crates/edit_prediction_context/src/declaration_scoring.rs
+++ b/crates/edit_prediction_context/src/declaration_scoring.rs
@@ -1,4 +1,3 @@
-use gpui::{App, Entity};
 use itertools::Itertools as _;
 use language::BufferSnapshot;
 use serde::Serialize;
@@ -7,8 +6,9 @@ use strum::EnumIter;
 use text::{OffsetRangeExt, Point, ToPoint};
 
 use crate::{
-    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier, SyntaxIndex,
+    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
     reference::{Reference, ReferenceRegion},
+    syntax_index::SyntaxIndexState,
     text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 };
 
@@ -50,14 +50,13 @@ impl ScoredSnippet {
     }
 }
 
-fn scored_snippets(
-    index: Entity<SyntaxIndex>,
+pub fn scored_snippets(
+    index: &SyntaxIndexState,
     excerpt: &EditPredictionExcerpt,
     excerpt_text: &EditPredictionExcerptText,
     identifier_to_references: HashMap<Identifier, Vec<Reference>>,
     cursor_offset: usize,
     current_buffer: &BufferSnapshot,
-    cx: &App,
 ) -> Vec<ScoredSnippet> {
     let containing_range_identifier_occurrences =
         IdentifierOccurrences::within_string(&excerpt_text.body);
@@ -74,22 +73,19 @@ fn scored_snippets(
     identifier_to_references
         .into_iter()
         .flat_map(|(identifier, references)| {
-            let declarations = index
-                .read(cx)
-                // todo! pick a limit
-                .declarations_for_identifier::<16>(&identifier, cx);
+            // todo! pick a limit
+            let declarations = index.declarations_for_identifier::<16>(&identifier);
             let declaration_count = declarations.len();
 
             declarations
                 .iter()
                 .filter_map(|declaration| match declaration {
                     Declaration::Buffer {
+                        buffer_id,
                         declaration: buffer_declaration,
-                        buffer,
+                        ..
                     } => {
-                        let is_same_file = buffer
-                            .read_with(cx, |buffer, _| buffer.remote_id())
-                            .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
+                        let is_same_file = buffer_id == &current_buffer.remote_id();
 
                         if is_same_file {
                             range_intersection(
@@ -127,8 +123,7 @@ fn scored_snippets(
                             declaration_line_distance_rank,
                             (is_same_file, declaration_line_distance, declaration),
                         )| {
-                            let same_file_declaration_count =
-                                index.read(cx).file_declaration_count(declaration);
+                            let same_file_declaration_count = index.file_declaration_count(declaration);
 
                             score_snippet(
                                 &identifier,
@@ -143,7 +138,6 @@ fn scored_snippets(
                                 &adjacent_identifier_occurrences,
                                 cursor_point,
                                 current_buffer,
-                                cx,
                             )
                         },
                     )
@@ -177,7 +171,6 @@ fn score_snippet(
     adjacent_identifier_occurrences: &IdentifierOccurrences,
     cursor: Point,
     current_buffer: &BufferSnapshot,
-    cx: &App,
 ) -> Option<ScoredSnippet> {
     let is_referenced_nearby = references
         .iter()
@@ -195,10 +188,9 @@ fn score_snippet(
         .min()
         .unwrap();
 
-    let item_source_occurrences =
-        IdentifierOccurrences::within_string(&declaration.item_text(cx).0);
+    let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
     let item_signature_occurrences =
-        IdentifierOccurrences::within_string(&declaration.signature_text(cx).0);
+        IdentifierOccurrences::within_string(&declaration.signature_text().0);
     let containing_range_vs_item_jaccard = jaccard_similarity(
         containing_range_identifier_occurrences,
         &item_source_occurrences,
@@ -285,7 +277,7 @@ impl ScoreInputs {
         // Score related to how likely this is the correct declaration, range 0 to 1
         let accuracy_score = if self.is_same_file {
             // TODO: use declaration_line_distance_rank
-            (0.5 / self.same_file_declaration_count as f32)
+            1.0 / self.same_file_declaration_count as f32
         } else {
             1.0 / self.declaration_count as f32
         };
@@ -309,173 +301,3 @@ impl ScoreInputs { } } } - -#[cfg(test)] -mod tests { - use super::*; - use std::sync::Arc; - - use gpui::{TestAppContext, prelude::*}; - use indoc::indoc; - use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project}; - use serde_json::json; - use settings::SettingsStore; - use text::ToOffset; - use util::path; - - use crate::{EditPredictionExcerptOptions, references_in_excerpt}; - - #[gpui::test] - async fn test_call_site(cx: &mut TestAppContext) { - let (project, index, _rust_lang_id) = init_test(cx).await; - - let buffer = project - .update(cx, |project, cx| { - let project_path = project.find_project_path("c.rs", cx).unwrap(); - project.open_buffer(project_path, cx) - }) - .await - .unwrap(); - - cx.run_until_parked(); - - // first process_data call site - let cursor_point = language::Point::new(8, 21); - let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); - let excerpt = EditPredictionExcerpt::select_from_buffer( - cursor_point, - &buffer_snapshot, - &EditPredictionExcerptOptions { - max_bytes: 40, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.5, - include_parent_signatures: false, - }, - ) - .unwrap(); - let excerpt_text = excerpt.text(&buffer_snapshot); - let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer_snapshot); - let cursor_offset = cursor_point.to_offset(&buffer_snapshot); - - let snippets = cx.update(|cx| { - scored_snippets( - index, - &excerpt, - &excerpt_text, - references, - cursor_offset, - &buffer_snapshot, - cx, - ) - }); - - assert_eq!(snippets.len(), 1); - assert_eq!(snippets[0].identifier.name.as_ref(), "process_data"); - drop(buffer); - } - - async fn init_test( - cx: &mut TestAppContext, - ) -> (Entity, Entity, LanguageId) { - cx.update(|cx| { - let settings_store = SettingsStore::test(cx); - cx.set_global(settings_store); - language::init(cx); - Project::init_settings(cx); - }); - - let fs = FakeFs::new(cx.executor()); - fs.insert_tree( - path!("/root"), - json!({ - "a.rs": indoc! {r#" - fn main() { - let x = 1; - let y = 2; - let z = add(x, y); - println!("Result: {}", z); - } - - fn add(a: i32, b: i32) -> i32 { - a + b - } - "#}, - "b.rs": indoc! {" - pub struct Config { - pub name: String, - pub value: i32, - } - - impl Config { - pub fn new(name: String, value: i32) -> Self { - Config { name, value } - } - } - "}, - "c.rs": indoc! {r#" - use std::collections::HashMap; - - fn main() { - let args: Vec = std::env::args().collect(); - let data: Vec = args[1..] 
- .iter() - .filter_map(|s| s.parse().ok()) - .collect(); - let result = process_data(data); - println!("{:?}", result); - } - - fn process_data(data: Vec) -> HashMap { - let mut counts = HashMap::new(); - for value in data { - *counts.entry(value).or_insert(0) += 1; - } - counts - } - - #[cfg(test)] - mod tests { - use super::*; - - #[test] - fn test_process_data() { - let data = vec![1, 2, 2, 3]; - let result = process_data(data); - assert_eq!(result.get(&2), Some(&2)); - } - } - "#} - }), - ) - .await; - let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; - let language_registry = project.read_with(cx, |project, _| project.languages().clone()); - let lang = rust_lang(); - let lang_id = lang.id(); - language_registry.add(Arc::new(lang)); - - let index = cx.new(|cx| SyntaxIndex::new(&project, cx)); - cx.run_until_parked(); - - (project, index, lang_id) - } - - fn rust_lang() -> Language { - Language::new( - LanguageConfig { - name: "Rust".into(), - matcher: LanguageMatcher { - path_suffixes: vec!["rs".to_string()], - ..Default::default() - }, - ..Default::default() - }, - Some(tree_sitter_rust::LANGUAGE.into()), - ) - .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) - .unwrap() - .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) - .unwrap() - } -} diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index ef14896c27d6acbe8f06364ae566ea1b8ef588cc..ff999964f4a7b4634d80680aaa06e4d09c7948a8 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -8,5 +8,213 @@ mod text_similarity; pub use declaration::{BufferDeclaration, Declaration, FileDeclaration, Identifier}; pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; +use gpui::{App, AppContext as _, Entity, Task}; +use language::BufferSnapshot; pub use reference::references_in_excerpt; pub use syntax_index::SyntaxIndex; +use text::{Point, ToOffset as _}; + +use crate::declaration_scoring::{ScoredSnippet, scored_snippets}; + +pub struct EditPredictionContext { + excerpt: EditPredictionExcerpt, + excerpt_text: EditPredictionExcerptText, + snippets: Vec, +} + +impl EditPredictionContext { + pub fn gather( + cursor_point: Point, + buffer: BufferSnapshot, + excerpt_options: EditPredictionExcerptOptions, + syntax_index: Entity, + cx: &mut App, + ) -> Task { + let index_state = syntax_index.read_with(cx, |index, _cx| index.state().clone()); + cx.background_spawn(async move { + let index_state = index_state.lock().await; + + let excerpt = + EditPredictionExcerpt::select_from_buffer(cursor_point, &buffer, &excerpt_options) + .unwrap(); + let excerpt_text = excerpt.text(&buffer); + let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer); + let cursor_offset = cursor_point.to_offset(&buffer); + + let snippets = scored_snippets( + &index_state, + &excerpt, + &excerpt_text, + references, + cursor_offset, + &buffer, + ); + + Self { + excerpt, + excerpt_text, + snippets, + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::Arc; + + use gpui::{Entity, TestAppContext}; + use indoc::indoc; + use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; + use project::{FakeFs, Project}; + use serde_json::json; + use settings::SettingsStore; + use util::path; + + use 
crate::{EditPredictionExcerptOptions, SyntaxIndex}; + + #[gpui::test] + async fn test_call_site(cx: &mut TestAppContext) { + let (project, index, _rust_lang_id) = init_test(cx).await; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + // first process_data call site + let cursor_point = language::Point::new(8, 21); + let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot()); + + let context = cx + .update(|cx| { + EditPredictionContext::gather( + cursor_point, + buffer_snapshot, + EditPredictionExcerptOptions { + max_bytes: 40, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + include_parent_signatures: false, + }, + index, + cx, + ) + }) + .await; + + assert_eq!(context.snippets.len(), 1); + assert_eq!(context.snippets[0].identifier.name.as_ref(), "process_data"); + drop(buffer); + } + + async fn init_test( + cx: &mut TestAppContext, + ) -> (Entity, Entity, LanguageId) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "a.rs": indoc! {r#" + fn main() { + let x = 1; + let y = 2; + let z = add(x, y); + println!("Result: {}", z); + } + + fn add(a: i32, b: i32) -> i32 { + a + b + } + "#}, + "b.rs": indoc! {" + pub struct Config { + pub name: String, + pub value: i32, + } + + impl Config { + pub fn new(name: String, value: i32) -> Self { + Config { name, value } + } + } + "}, + "c.rs": indoc! {r#" + use std::collections::HashMap; + + fn main() { + let args: Vec = std::env::args().collect(); + let data: Vec = args[1..] 
+ .iter() + .filter_map(|s| s.parse().ok()) + .collect(); + let result = process_data(data); + println!("{:?}", result); + } + + fn process_data(data: Vec) -> HashMap { + let mut counts = HashMap::new(); + for value in data { + *counts.entry(value).or_insert(0) += 1; + } + counts + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_process_data() { + let data = vec![1, 2, 2, 3]; + let result = process_data(data); + assert_eq!(result.get(&2), Some(&2)); + } + } + "#} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let lang = rust_lang(); + let lang_id = lang.id(); + language_registry.add(Arc::new(lang)); + + let index = cx.new(|cx| SyntaxIndex::new(&project, cx)); + cx.run_until_parked(); + + (project, index, lang_id) + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm")) + .unwrap() + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs index 10059a18f9785e2e0574844f70a3bba37723cd2f..bef993d0c50dd06602b0343082929cb8677a3bb8 100644 --- a/crates/edit_prediction_context/src/syntax_index.rs +++ b/crates/edit_prediction_context/src/syntax_index.rs @@ -1,10 +1,14 @@ +use std::sync::Arc; + use collections::{HashMap, HashSet}; +use futures::lock::Mutex; use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; use language::{Buffer, BufferEvent}; use project::buffer_store::{BufferStore, BufferStoreEvent}; use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; use project::{PathChange, Project, ProjectEntryId, ProjectPath}; use slotmap::SlotMap; +use text::BufferId; use util::{ResultExt as _, debug_panic, some_or_debug_panic}; use crate::declaration::{ @@ -15,6 +19,8 @@ use crate::outline::declarations_in_buffer; // TODO: // // * Skip for remote projects +// +// * Consider making SyntaxIndex not an Entity. 
// Potential future improvements: // @@ -34,13 +40,19 @@ use crate::outline::declarations_in_buffer; // * Concurrent slotmap // // * Use queue for parsing +// pub struct SyntaxIndex { + state: Arc>, + project: WeakEntity, +} + +#[derive(Default)] +pub struct SyntaxIndexState { declarations: SlotMap, identifiers: HashMap>, files: HashMap, - buffers: HashMap, BufferState>, - project: WeakEntity, + buffers: HashMap, } #[derive(Debug, Default)] @@ -58,11 +70,8 @@ struct BufferState { impl SyntaxIndex { pub fn new(project: &Entity, cx: &mut Context) -> Self { let mut this = Self { - declarations: SlotMap::with_key(), - identifiers: HashMap::default(), project: project.downgrade(), - files: HashMap::default(), - buffers: HashMap::default(), + state: Arc::new(Mutex::new(SyntaxIndexState::default())), }; let worktree_store = project.read(cx).worktree_store(); @@ -97,90 +106,6 @@ impl SyntaxIndex { this } - pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { - self.declarations.get(id) - } - - pub fn declarations_for_identifier( - &self, - identifier: &Identifier, - cx: &App, - ) -> Vec { - // make sure to not have a large stack allocation - assert!(N < 32); - - let Some(declaration_ids) = self.identifiers.get(&identifier) else { - return vec![]; - }; - - let mut result = Vec::with_capacity(N); - let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); - let mut file_declarations = Vec::new(); - - for declaration_id in declaration_ids { - let declaration = self.declarations.get(*declaration_id); - let Some(declaration) = some_or_debug_panic(declaration) else { - continue; - }; - match declaration { - Declaration::Buffer { buffer, .. } => { - if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| { - project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx)) - }) { - included_buffer_entry_ids.push(entry_id); - result.push(declaration.clone()); - if result.len() == N { - return result; - } - } - } - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(project_entry_id) { - file_declarations.push(declaration.clone()); - } - } - } - } - - for declaration in file_declarations { - match declaration { - Declaration::File { - project_entry_id, .. - } => { - if !included_buffer_entry_ids.contains(&project_entry_id) { - result.push(declaration); - - if result.len() == N { - return result; - } - } - } - Declaration::Buffer { .. } => {} - } - } - - result - } - - pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { - match declaration { - Declaration::File { - project_entry_id, .. - } => self - .files - .get(project_entry_id) - .map(|file_state| file_state.declarations.len()) - .unwrap_or_default(), - Declaration::Buffer { buffer, .. 
} => self - .buffers - .get(buffer) - .map(|buffer_state| buffer_state.declarations.len()) - .unwrap_or_default(), - } - } - fn handle_worktree_store_event( &mut self, _worktree_store: Entity, @@ -190,21 +115,33 @@ impl SyntaxIndex { use WorktreeStoreEvent::*; match event { WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { - for (path, entry_id, path_change) in updated_entries_set.iter() { - if let PathChange::Removed = path_change { - self.files.remove(entry_id); - } else { - let project_path = ProjectPath { - worktree_id: *worktree_id, - path: path.clone(), - }; - self.update_file(*entry_id, project_path, cx); + let state = Arc::downgrade(&self.state); + let worktree_id = *worktree_id; + let updated_entries_set = updated_entries_set.clone(); + cx.spawn(async move |this, cx| { + let Some(state) = state.upgrade() else { return }; + for (path, entry_id, path_change) in updated_entries_set.iter() { + if let PathChange::Removed = path_change { + state.lock().await.files.remove(entry_id); + } else { + let project_path = ProjectPath { + worktree_id, + path: path.clone(), + }; + this.update(cx, |this, cx| { + this.update_file(*entry_id, project_path, cx); + }) + .ok(); + } } - } + }) + .detach(); } WorktreeDeletedEntry(_worktree_id, project_entry_id) => { - // TODO: Is this needed? - self.files.remove(project_entry_id); + let project_entry_id = *project_entry_id; + self.with_state(cx, move |state| { + state.files.remove(&project_entry_id); + }) } _ => {} } @@ -226,15 +163,42 @@ impl SyntaxIndex { } } + pub fn state(&self) -> &Arc> { + &self.state + } + + fn with_state(&self, cx: &mut App, f: impl FnOnce(&mut SyntaxIndexState) + Send + 'static) { + if let Some(mut state) = self.state.try_lock() { + f(&mut state); + return; + } + let state = Arc::downgrade(&self.state); + cx.background_spawn(async move { + let Some(state) = state.upgrade() else { + return None; + }; + let mut state = state.lock().await; + Some(f(&mut state)) + }) + .detach(); + } + fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { - self.buffers - .insert(buffer.downgrade(), BufferState::default()); - let weak_buf = buffer.downgrade(); - cx.observe_release(buffer, move |this, _buffer, _cx| { - this.buffers.remove(&weak_buf); + let buffer_id = buffer.read(cx).remote_id(); + cx.observe_release(buffer, move |this, _buffer, cx| { + this.with_state(cx, move |state| { + if let Some(buffer_state) = state.buffers.remove(&buffer_id) { + SyntaxIndexState::remove_buffer_declarations( + &buffer_state.declarations, + &mut state.declarations, + &mut state.identifiers, + ); + } + }) }) .detach(); cx.subscribe(buffer, Self::handle_buffer_event).detach(); + self.update_buffer(buffer.clone(), cx); } @@ -250,10 +214,19 @@ impl SyntaxIndex { } } - fn update_buffer(&mut self, buffer: Entity, cx: &Context) { - let mut parse_status = buffer.read(cx).parse_status(); + fn update_buffer(&mut self, buffer_entity: Entity, cx: &mut Context) { + let buffer = buffer_entity.read(cx); + + let Some(project_entry_id) = + project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx)) + else { + return; + }; + let buffer_id = buffer.remote_id(); + + let mut parse_status = buffer.parse_status(); let snapshot_task = cx.spawn({ - let weak_buffer = buffer.downgrade(); + let weak_buffer = buffer_entity.downgrade(); async move |_, cx| { while *parse_status.borrow() != language::ParseStatus::Idle { parse_status.changed().await?; @@ -264,75 +237,72 @@ impl SyntaxIndex { let parse_task = cx.background_spawn(async move { let snapshot = 
snapshot_task.await?; + let rope = snapshot.text.as_rope().clone(); - anyhow::Ok( + anyhow::Ok(( + rope, declarations_in_buffer(&snapshot) .into_iter() - .map(|item| { - ( - item.parent_index, - BufferDeclaration::from_outline(item, &snapshot), - ) - }) + .map(|item| (item.parent_index, BufferDeclaration::from_outline(item))) .collect::>(), - ) + )) }); let task = cx.spawn({ - let weak_buffer = buffer.downgrade(); async move |this, cx| { - let Ok(declarations) = parse_task.await else { + let Ok((rope, declarations)) = parse_task.await else { return; }; - this.update(cx, |this, _cx| { - let buffer_state = this - .buffers - .entry(weak_buffer.clone()) - .or_insert_with(Default::default); - - for old_declaration_id in &buffer_state.declarations { - let Some(declaration) = this.declarations.remove(*old_declaration_id) - else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = - this.identifiers.get_mut(declaration.identifier()) - { - identifier_declarations.remove(old_declaration_id); + this.update(cx, move |this, cx| { + this.with_state(cx, move |state| { + let buffer_state = state + .buffers + .entry(buffer_id) + .or_insert_with(Default::default); + + SyntaxIndexState::remove_buffer_declarations( + &buffer_state.declarations, + &mut state.declarations, + &mut state.identifiers, + ); + + let mut new_ids = Vec::with_capacity(declarations.len()); + state.declarations.reserve(declarations.len()); + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = state.declarations.insert(Declaration::Buffer { + rope: rope.clone(), + buffer_id, + declaration, + project_entry_id, + }); + new_ids.push(declaration_id); + + state + .identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); } - } - let mut new_ids = Vec::with_capacity(declarations.len()); - this.declarations.reserve(declarations.len()); - for (parent_index, mut declaration) in declarations { - declaration.parent = parent_index - .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); - - let identifier = declaration.identifier.clone(); - let declaration_id = this.declarations.insert(Declaration::Buffer { - buffer: weak_buffer.clone(), - declaration, - }); - new_ids.push(declaration_id); - - this.identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); - } - - buffer_state.declarations = new_ids; + buffer_state.declarations = new_ids; + }); }) .ok(); } }); - self.buffers - .entry(buffer.downgrade()) - .or_insert_with(Default::default) - .task = Some(task); + self.with_state(cx, move |state| { + state + .buffers + .entry(buffer_id) + .or_insert_with(Default::default) + .task = Some(task) + }); } fn update_file( @@ -376,14 +346,10 @@ impl SyntaxIndex { let parse_task = cx.background_spawn(async move { let snapshot = snapshot_task.await?; + let rope = snapshot.as_rope(); let declarations = declarations_in_buffer(&snapshot) .into_iter() - .map(|item| { - ( - item.parent_index, - FileDeclaration::from_outline(item, &snapshot), - ) - }) + .map(|item| (item.parent_index, FileDeclaration::from_outline(item, rope))) .collect::>(); anyhow::Ok(declarations) }); @@ -394,52 +360,158 @@ impl SyntaxIndex { let Ok(declarations) = parse_task.await else { return; }; - this.update(cx, |this, _cx| { - let file_state = this.files.entry(entry_id).or_insert_with(Default::default); - - for 
old_declaration_id in &file_state.declarations { - let Some(declaration) = this.declarations.remove(*old_declaration_id) - else { - debug_panic!("declaration not found"); - continue; - }; - if let Some(identifier_declarations) = - this.identifiers.get_mut(declaration.identifier()) - { - identifier_declarations.remove(old_declaration_id); + this.update(cx, |this, cx| { + this.with_state(cx, move |state| { + let file_state = + state.files.entry(entry_id).or_insert_with(Default::default); + + for old_declaration_id in &file_state.declarations { + let Some(declaration) = state.declarations.remove(*old_declaration_id) + else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = + state.identifiers.get_mut(declaration.identifier()) + { + identifier_declarations.remove(old_declaration_id); + } } - } - let mut new_ids = Vec::with_capacity(declarations.len()); - this.declarations.reserve(declarations.len()); + let mut new_ids = Vec::with_capacity(declarations.len()); + state.declarations.reserve(declarations.len()); + + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = state.declarations.insert(Declaration::File { + project_entry_id: entry_id, + declaration, + }); + new_ids.push(declaration_id); + + state + .identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); + } - for (parent_index, mut declaration) in declarations { - declaration.parent = parent_index - .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + file_state.declarations = new_ids; + }); + }) + .ok(); + } + }); - let identifier = declaration.identifier.clone(); - let declaration_id = this.declarations.insert(Declaration::File { - project_entry_id: entry_id, - declaration, - }); - new_ids.push(declaration_id); + self.with_state(cx, move |state| { + state + .files + .entry(entry_id) + .or_insert_with(Default::default) + .task = Some(task); + }); + } +} - this.identifiers - .entry(identifier) - .or_default() - .insert(declaration_id); +impl SyntaxIndexState { + pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { + self.declarations.get(id) + } + + pub fn declarations_for_identifier( + &self, + identifier: &Identifier, + ) -> Vec { + // make sure to not have a large stack allocation + assert!(N < 32); + + let Some(declaration_ids) = self.identifiers.get(&identifier) else { + return vec![]; + }; + + let mut result = Vec::with_capacity(N); + let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); + let mut file_declarations = Vec::new(); + + for declaration_id in declaration_ids { + let declaration = self.declarations.get(*declaration_id); + let Some(declaration) = some_or_debug_panic(declaration) else { + continue; + }; + match declaration { + Declaration::Buffer { + project_entry_id, .. + } => { + included_buffer_entry_ids.push(*project_entry_id); + result.push(declaration.clone()); + if result.len() == N { + return result; + } + } + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(&project_entry_id) { + file_declarations.push(declaration.clone()); } + } + } + } - file_state.declarations = new_ids; - }) - .ok(); + for declaration in file_declarations { + match declaration { + Declaration::File { + project_entry_id, .. 
+ } => { + if !included_buffer_entry_ids.contains(&project_entry_id) { + result.push(declaration); + + if result.len() == N { + return result; + } + } + } + Declaration::Buffer { .. } => {} } - }); + } - self.files - .entry(entry_id) - .or_insert_with(Default::default) - .task = Some(task); + result + } + + pub fn file_declaration_count(&self, declaration: &Declaration) -> usize { + match declaration { + Declaration::File { + project_entry_id, .. + } => self + .files + .get(project_entry_id) + .map(|file_state| file_state.declarations.len()) + .unwrap_or_default(), + Declaration::Buffer { buffer_id, .. } => self + .buffers + .get(buffer_id) + .map(|buffer_state| buffer_state.declarations.len()) + .unwrap_or_default(), + } + } + + fn remove_buffer_declarations( + old_declaration_ids: &[DeclarationId], + declarations: &mut SlotMap, + identifiers: &mut HashMap>, + ) { + for old_declaration_id in old_declaration_ids { + let Some(declaration) = declarations.remove(*old_declaration_id) else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = identifiers.get_mut(declaration.identifier()) { + identifier_declarations.remove(old_declaration_id); + } + } } } @@ -448,11 +520,10 @@ mod tests { use super::*; use std::{path::Path, sync::Arc}; - use futures::channel::oneshot; use gpui::TestAppContext; use indoc::indoc; use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; - use project::{FakeFs, Project, ProjectItem}; + use project::{FakeFs, Project}; use serde_json::json; use settings::SettingsStore; use text::OffsetRangeExt as _; @@ -468,8 +539,10 @@ mod tests { language_id: rust_lang_id, }; - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(&main, cx); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); let decl = expect_file_decl("c.rs", &decls[0], &project, cx); @@ -490,15 +563,17 @@ mod tests { language_id: rust_lang_id, }; - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(&test_process_data, cx); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); let decl = expect_file_decl("c.rs", &decls[0], &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); - let parent = index.declaration(parent_id).unwrap(); + let parent = index_state.declaration(parent_id).unwrap(); let parent_decl = expect_file_decl("c.rs", &parent, &project, cx); assert_eq!( parent_decl.identifier, @@ -529,16 +604,18 @@ mod tests { cx.run_until_parked(); - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(&test_process_data, cx); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&test_process_data); assert_eq!(decls.len(), 1); - let decl = expect_buffer_decl("c.rs", &decls[0], cx); + let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); assert_eq!(decl.identifier, test_process_data); let parent_id = decl.parent.unwrap(); - let parent = 
index.declaration(parent_id).unwrap(); - let parent_decl = expect_buffer_decl("c.rs", &parent, cx); + let parent = index_state.declaration(parent_id).unwrap(); + let parent_decl = expect_buffer_decl("c.rs", &parent, &project, cx); assert_eq!( parent_decl.identifier, Identifier { @@ -556,16 +633,13 @@ mod tests { async fn test_declarations_limt(cx: &mut TestAppContext) { let (_, index, rust_lang_id) = init_test(cx).await; - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<1>( - &Identifier { - name: "main".into(), - language_id: rust_lang_id, - }, - cx, - ); - assert_eq!(decls.len(), 1); + let index_state = index.read_with(cx, |index, _cx| index.state().clone()); + let index_state = index_state.lock().await; + let decls = index_state.declarations_for_identifier::<1>(&Identifier { + name: "main".into(), + language_id: rust_lang_id, }); + assert_eq!(decls.len(), 1); } #[gpui::test] @@ -587,31 +661,31 @@ mod tests { cx.run_until_parked(); - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(&main, cx); - assert_eq!(decls.len(), 2); - let decl = expect_buffer_decl("c.rs", &decls[0], cx); - assert_eq!(decl.identifier, main); - assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279); + let index_state_arc = index.read_with(cx, |index, _cx| index.state().clone()); + { + let index_state = index_state_arc.lock().await; - expect_file_decl("a.rs", &decls[1], &project, cx); - }); + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&main); + assert_eq!(decls.len(), 2); + let decl = expect_buffer_decl("c.rs", &decls[0], &project, cx); + assert_eq!(decl.identifier, main); + assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279); + + expect_file_decl("a.rs", &decls[1], &project, cx); + }); + } // Drop the buffer and wait for release - let (release_tx, release_rx) = oneshot::channel(); - cx.update(|cx| { - cx.observe_release(&buffer, |_, _| { - release_tx.send(()).ok(); - }) - .detach(); + cx.update(|_| { + drop(buffer); }); - drop(buffer); - cx.run_until_parked(); - release_rx.await.ok(); cx.run_until_parked(); - index.read_with(cx, |index, cx| { - let decls = index.declarations_for_identifier::<8>(&main, cx); + let index_state = index_state_arc.lock().await; + + cx.update(|cx| { + let decls = index_state.declarations_for_identifier::<8>(&main); assert_eq!(decls.len(), 2); expect_file_decl("c.rs", &decls[0], &project, cx); expect_file_decl("a.rs", &decls[1], &project, cx); @@ -621,24 +695,20 @@ mod tests { fn expect_buffer_decl<'a>( path: &str, declaration: &'a Declaration, + project: &Entity, cx: &App, ) -> &'a BufferDeclaration { if let Declaration::Buffer { declaration, - buffer, + project_entry_id, + .. } = declaration { - assert_eq!( - buffer - .upgrade() - .unwrap() - .read(cx) - .project_path(cx) - .unwrap() - .path - .as_ref(), - Path::new(path), - ); + let project_path = project + .read(cx) + .path_for_entry(*project_entry_id, cx) + .unwrap(); + assert_eq!(project_path.path.as_ref(), Path::new(path),); declaration } else { panic!("Expected a buffer declaration, found {:?}", declaration);