diff --git a/Cargo.lock b/Cargo.lock index 4709dd34f4ab090d06e49827fc3b3b04eece6659..43a2fc4041fbf76b57e62d335e94c695ef07fc12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5138,11 +5138,19 @@ dependencies = [ name = "edit_prediction_context" version = "0.1.0" dependencies = [ + "anyhow", + "arrayvec", + "collections", + "futures 0.3.31", "gpui", "indoc", "language", "log", "pretty_assertions", + "project", + "serde_json", + "settings", + "slotmap", "text", "tree-sitter", "util", diff --git a/Cargo.toml b/Cargo.toml index d3bbb85f3dbb90e52fc9b95573f6367a303c6479..e69885a835f5f579ac5ac5fa0063f5291a1d01f5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -634,6 +634,7 @@ sha2 = "0.10" shellexpand = "2.1.0" shlex = "1.3.0" simplelog = "0.12.2" +slotmap = "1.0.6" smallvec = { version = "1.6", features = ["union"] } smol = "2.0" sqlformat = "0.2" diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index 6729dcd39b67c5374d745b1811b73a8b8af4f2aa..ad455b0a4ecb1746debafd23f0503b4365f9a0cf 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -12,18 +12,28 @@ workspace = true path = "src/edit_prediction_context.rs" [dependencies] +anyhow.workspace = true +arrayvec.workspace = true +collections.workspace = true +gpui.workspace = true language.workspace = true -workspace-hack.workspace = true -tree-sitter.workspace = true -text.workspace = true log.workspace = true +project.workspace = true +slotmap.workspace = true +text.workspace = true +tree-sitter.workspace = true util.workspace = true +workspace-hack.workspace = true [dev-dependencies] +futures.workspace = true gpui = { workspace = true, features = ["test-support"] } indoc.workspace = true language = { workspace = true, features = ["test-support"] } pretty_assertions.workspace = true +project = {workspace= true, features = ["test-support"]} +serde_json.workspace = true +settings = {workspace= true, features = ["test-support"]} text = { workspace = true, features = ["test-support"] } util = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 7e6cad45a8032cb1afaafd95eb10c80e61cff097..acfb89880c3ed9e7b1ebcacd4b5fa313830165ba 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -1,3 +1,8 @@ mod excerpt; +mod outline; +mod reference; +mod tree_sitter_index; -pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions}; +pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText}; +pub use reference::references_in_excerpt; +pub use tree_sitter_index::{BufferDeclaration, Declaration, FileDeclaration, TreeSitterIndex}; diff --git a/crates/edit_prediction_context/src/excerpt.rs b/crates/edit_prediction_context/src/excerpt.rs index 5a20da76f71834d226707fe36f93d3667158372c..c6caa6a1b7b4076cf739c1ac198656b9fba431a6 100644 --- a/crates/edit_prediction_context/src/excerpt.rs +++ b/crates/edit_prediction_context/src/excerpt.rs @@ -38,7 +38,28 @@ pub struct EditPredictionExcerpt { pub size: usize, } +#[derive(Clone)] +pub struct EditPredictionExcerptText { + pub body: String, + pub parent_signatures: Vec, +} + impl EditPredictionExcerpt { + pub fn text(&self, buffer: &BufferSnapshot) -> EditPredictionExcerptText { + let body = buffer + .text_for_range(self.range.clone()) + .collect::(); + let parent_signatures = self + .parent_signature_ranges + .iter() + .map(|range| buffer.text_for_range(range.clone()).collect::()) + .collect(); + EditPredictionExcerptText { + body, + parent_signatures, + } + } + /// Selects an excerpt around a buffer position, attempting to choose logical boundaries based /// on TreeSitter structure and approximately targeting a goal ratio of bytesbefore vs after the /// cursor. When `include_parent_signatures` is true, the excerpt also includes the signatures diff --git a/crates/edit_prediction_context/src/outline.rs b/crates/edit_prediction_context/src/outline.rs new file mode 100644 index 0000000000000000000000000000000000000000..492352add1fd4c666eab3b12989f9b801d03570f --- /dev/null +++ b/crates/edit_prediction_context/src/outline.rs @@ -0,0 +1,130 @@ +use language::{BufferSnapshot, LanguageId, SyntaxMapMatches}; +use std::{cmp::Reverse, ops::Range, sync::Arc}; + +// TODO: +// +// * how to handle multiple name captures? for now last one wins +// +// * annotation ranges +// +// * new "signature" capture for outline queries +// +// * Check parent behavior of "int x, y = 0" declarations in a test + +pub struct OutlineDeclaration { + pub parent_index: Option, + pub identifier: Identifier, + pub item_range: Range, + pub signature_range: Range, +} + +#[derive(Debug, Clone, Eq, PartialEq, Hash)] +pub struct Identifier { + pub name: Arc, + pub language_id: LanguageId, +} + +pub fn declarations_in_buffer(buffer: &BufferSnapshot) -> Vec { + declarations_overlapping_range(0..buffer.len(), buffer) +} + +pub fn declarations_overlapping_range( + range: Range, + buffer: &BufferSnapshot, +) -> Vec { + let mut declarations = OutlineIterator::new(range, buffer).collect::>(); + declarations.sort_unstable_by_key(|item| (item.item_range.start, Reverse(item.item_range.end))); + + let mut parent_stack: Vec<(usize, Range)> = Vec::new(); + for (index, declaration) in declarations.iter_mut().enumerate() { + while let Some((top_parent_index, top_parent_range)) = parent_stack.last() { + if declaration.item_range.start >= top_parent_range.end { + parent_stack.pop(); + } else { + declaration.parent_index = Some(*top_parent_index); + break; + } + } + parent_stack.push((index, declaration.item_range.clone())); + } + declarations +} + +/// Iterates outline items without being ordered w.r.t. nested items and without populating +/// `parent`. +pub struct OutlineIterator<'a> { + buffer: &'a BufferSnapshot, + matches: SyntaxMapMatches<'a>, +} + +impl<'a> OutlineIterator<'a> { + pub fn new(range: Range, buffer: &'a BufferSnapshot) -> Self { + let matches = buffer.syntax.matches(range, &buffer.text, |grammar| { + grammar.outline_config.as_ref().map(|c| &c.query) + }); + + Self { buffer, matches } + } +} + +impl<'a> Iterator for OutlineIterator<'a> { + type Item = OutlineDeclaration; + + fn next(&mut self) -> Option { + while let Some(mat) = self.matches.peek() { + let config = self.matches.grammars()[mat.grammar_index] + .outline_config + .as_ref() + .unwrap(); + + let mut name_range = None; + let mut item_range = None; + let mut signature_start = None; + let mut signature_end = None; + + let mut add_to_signature = |range: Range| { + if signature_start.is_none() { + signature_start = Some(range.start); + } + signature_end = Some(range.end); + }; + + for capture in mat.captures { + let range = capture.node.byte_range(); + if capture.index == config.name_capture_ix { + name_range = Some(range.clone()); + add_to_signature(range); + } else if Some(capture.index) == config.context_capture_ix + || Some(capture.index) == config.extra_context_capture_ix + { + add_to_signature(range); + } else if capture.index == config.item_capture_ix { + item_range = Some(range.clone()); + } + } + + let language_id = mat.language.id(); + self.matches.advance(); + + if let Some(name_range) = name_range + && let Some(item_range) = item_range + && let Some(signature_start) = signature_start + && let Some(signature_end) = signature_end + { + let name = self + .buffer + .text_for_range(name_range) + .collect::() + .into(); + + return Some(OutlineDeclaration { + identifier: Identifier { name, language_id }, + item_range: item_range, + signature_range: signature_start..signature_end, + parent_index: None, + }); + } + } + None + } +} diff --git a/crates/edit_prediction_context/src/reference.rs b/crates/edit_prediction_context/src/reference.rs new file mode 100644 index 0000000000000000000000000000000000000000..65d34e73bf20f62b24ac2a654af43fc3b83041a9 --- /dev/null +++ b/crates/edit_prediction_context/src/reference.rs @@ -0,0 +1,109 @@ +use language::BufferSnapshot; +use std::collections::HashMap; +use std::ops::Range; + +use crate::{ + excerpt::{EditPredictionExcerpt, EditPredictionExcerptText}, + outline::Identifier, +}; + +#[derive(Debug)] +pub struct Reference { + pub identifier: Identifier, + pub range: Range, + pub region: ReferenceRegion, +} + +#[derive(Copy, Clone, Debug, Eq, PartialEq)] +pub enum ReferenceRegion { + Breadcrumb, + Nearby, +} + +pub fn references_in_excerpt( + excerpt: &EditPredictionExcerpt, + excerpt_text: &EditPredictionExcerptText, + snapshot: &BufferSnapshot, +) -> HashMap> { + let mut references = identifiers_in_range( + excerpt.range.clone(), + excerpt_text.body.as_str(), + ReferenceRegion::Nearby, + snapshot, + ); + + for (range, text) in excerpt + .parent_signature_ranges + .iter() + .zip(excerpt_text.parent_signatures.iter()) + { + references.extend(identifiers_in_range( + range.clone(), + text.as_str(), + ReferenceRegion::Breadcrumb, + snapshot, + )); + } + + let mut identifier_to_references: HashMap> = HashMap::new(); + for reference in references { + identifier_to_references + .entry(reference.identifier.clone()) + .or_insert_with(Vec::new) + .push(reference); + } + identifier_to_references +} + +/// Finds all nodes which have a "variable" match from the highlights query within the offset range. +pub fn identifiers_in_range( + range: Range, + range_text: &str, + reference_region: ReferenceRegion, + buffer: &BufferSnapshot, +) -> Vec { + let mut matches = buffer + .syntax + .matches(range.clone(), &buffer.text, |grammar| { + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) + }); + + let mut references = Vec::new(); + let mut last_added_range = None; + while let Some(mat) = matches.peek() { + let config = matches.grammars()[mat.grammar_index] + .highlights_config + .as_ref(); + + for capture in mat.captures { + if let Some(config) = config { + if config.identifier_capture_indices.contains(&capture.index) { + let node_range = capture.node.byte_range(); + + // sometimes multiple highlight queries match - this deduplicates them + if Some(node_range.clone()) == last_added_range { + continue; + } + + let identifier_text = + &range_text[node_range.start - range.start..node_range.end - range.start]; + references.push(Reference { + identifier: Identifier { + name: identifier_text.into(), + language_id: mat.language.id(), + }, + range: node_range.clone(), + region: reference_region, + }); + last_added_range = Some(node_range); + } + } + } + + matches.advance(); + } + references +} diff --git a/crates/edit_prediction_context/src/tree_sitter_index.rs b/crates/edit_prediction_context/src/tree_sitter_index.rs new file mode 100644 index 0000000000000000000000000000000000000000..4dc00941fe1b8a7a095fffd5605b040001c02eb7 --- /dev/null +++ b/crates/edit_prediction_context/src/tree_sitter_index.rs @@ -0,0 +1,833 @@ +use collections::{HashMap, HashSet}; +use gpui::{App, AppContext as _, Context, Entity, Task, WeakEntity}; +use language::{Buffer, BufferEvent, BufferSnapshot}; +use project::buffer_store::{BufferStore, BufferStoreEvent}; +use project::worktree_store::{WorktreeStore, WorktreeStoreEvent}; +use project::{PathChange, Project, ProjectEntryId, ProjectPath}; +use slotmap::SlotMap; +use std::ops::Range; +use std::sync::Arc; +use text::Anchor; +use util::{ResultExt as _, debug_panic, some_or_debug_panic}; + +use crate::outline::{Identifier, OutlineDeclaration, declarations_in_buffer}; + +// TODO: +// +// * Skip for remote projects + +// Potential future improvements: +// +// * Send multiple selected excerpt ranges. Challenge is that excerpt ranges influence which +// references are present and their scores. + +// Potential future optimizations: +// +// * Cache of buffers for files +// +// * Parse files directly instead of loading into a Rope. Make SyntaxMap generic to handle embedded +// languages? Will also need to find line boundaries, but that can be done by scanning characters in +// the flat representation. +// +// * Use something similar to slotmap without key versions. +// +// * Concurrent slotmap +// +// * Use queue for parsing + +slotmap::new_key_type! { + pub struct DeclarationId; +} + +pub struct TreeSitterIndex { + declarations: SlotMap, + identifiers: HashMap>, + files: HashMap, + buffers: HashMap, BufferState>, + project: WeakEntity, +} + +#[derive(Debug, Default)] +struct FileState { + declarations: Vec, + task: Option>, +} + +#[derive(Default)] +struct BufferState { + declarations: Vec, + task: Option>, +} + +#[derive(Debug, Clone)] +pub enum Declaration { + File { + project_entry_id: ProjectEntryId, + declaration: FileDeclaration, + }, + Buffer { + buffer: WeakEntity, + declaration: BufferDeclaration, + }, +} + +impl Declaration { + fn identifier(&self) -> &Identifier { + match self { + Declaration::File { declaration, .. } => &declaration.identifier, + Declaration::Buffer { declaration, .. } => &declaration.identifier, + } + } +} + +#[derive(Debug, Clone)] +pub struct FileDeclaration { + pub parent: Option, + pub identifier: Identifier, + pub item_range: Range, + pub signature_range: Range, + pub signature_text: Arc, +} + +#[derive(Debug, Clone)] +pub struct BufferDeclaration { + pub parent: Option, + pub identifier: Identifier, + pub item_range: Range, + pub signature_range: Range, +} + +impl TreeSitterIndex { + pub fn new(project: &Entity, cx: &mut Context) -> Self { + let mut this = Self { + declarations: SlotMap::with_key(), + identifiers: HashMap::default(), + project: project.downgrade(), + files: HashMap::default(), + buffers: HashMap::default(), + }; + + let worktree_store = project.read(cx).worktree_store(); + cx.subscribe(&worktree_store, Self::handle_worktree_store_event) + .detach(); + + for worktree in worktree_store + .read(cx) + .worktrees() + .map(|w| w.read(cx).snapshot()) + .collect::>() + { + for entry in worktree.files(false, 0) { + this.update_file( + entry.id, + ProjectPath { + worktree_id: worktree.id(), + path: entry.path.clone(), + }, + cx, + ); + } + } + + let buffer_store = project.read(cx).buffer_store().clone(); + for buffer in buffer_store.read(cx).buffers().collect::>() { + this.register_buffer(&buffer, cx); + } + cx.subscribe(&buffer_store, Self::handle_buffer_store_event) + .detach(); + + this + } + + pub fn declaration(&self, id: DeclarationId) -> Option<&Declaration> { + self.declarations.get(id) + } + + pub fn declarations_for_identifier( + &self, + identifier: Identifier, + cx: &App, + ) -> Vec { + // make sure to not have a large stack allocation + assert!(N < 32); + + let Some(declaration_ids) = self.identifiers.get(&identifier) else { + return vec![]; + }; + + let mut result = Vec::with_capacity(N); + let mut included_buffer_entry_ids = arrayvec::ArrayVec::<_, N>::new(); + let mut file_declarations = Vec::new(); + + for declaration_id in declaration_ids { + let declaration = self.declarations.get(*declaration_id); + let Some(declaration) = some_or_debug_panic(declaration) else { + continue; + }; + match declaration { + Declaration::Buffer { buffer, .. } => { + if let Ok(Some(entry_id)) = buffer.read_with(cx, |buffer, cx| { + project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx)) + }) { + included_buffer_entry_ids.push(entry_id); + result.push(declaration.clone()); + if result.len() == N { + return result; + } + } + } + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(project_entry_id) { + file_declarations.push(declaration.clone()); + } + } + } + } + + for declaration in file_declarations { + match declaration { + Declaration::File { + project_entry_id, .. + } => { + if !included_buffer_entry_ids.contains(&project_entry_id) { + result.push(declaration); + + if result.len() == N { + return result; + } + } + } + Declaration::Buffer { .. } => {} + } + } + + result + } + + fn handle_worktree_store_event( + &mut self, + _worktree_store: Entity, + event: &WorktreeStoreEvent, + cx: &mut Context, + ) { + use WorktreeStoreEvent::*; + match event { + WorktreeUpdatedEntries(worktree_id, updated_entries_set) => { + for (path, entry_id, path_change) in updated_entries_set.iter() { + if let PathChange::Removed = path_change { + self.files.remove(entry_id); + } else { + let project_path = ProjectPath { + worktree_id: *worktree_id, + path: path.clone(), + }; + self.update_file(*entry_id, project_path, cx); + } + } + } + WorktreeDeletedEntry(_worktree_id, project_entry_id) => { + // TODO: Is this needed? + self.files.remove(project_entry_id); + } + _ => {} + } + } + + fn handle_buffer_store_event( + &mut self, + _buffer_store: Entity, + event: &BufferStoreEvent, + cx: &mut Context, + ) { + use BufferStoreEvent::*; + match event { + BufferAdded(buffer) => self.register_buffer(buffer, cx), + BufferOpened { .. } + | BufferChangedFilePath { .. } + | BufferDropped { .. } + | SharedBufferClosed { .. } => {} + } + } + + fn register_buffer(&mut self, buffer: &Entity, cx: &mut Context) { + self.buffers + .insert(buffer.downgrade(), BufferState::default()); + let weak_buf = buffer.downgrade(); + cx.observe_release(buffer, move |this, _buffer, _cx| { + this.buffers.remove(&weak_buf); + }) + .detach(); + cx.subscribe(buffer, Self::handle_buffer_event).detach(); + self.update_buffer(buffer.clone(), cx); + } + + fn handle_buffer_event( + &mut self, + buffer: Entity, + event: &BufferEvent, + cx: &mut Context, + ) { + match event { + BufferEvent::Edited => self.update_buffer(buffer, cx), + _ => {} + } + } + + fn update_buffer(&mut self, buffer: Entity, cx: &Context) { + let mut parse_status = buffer.read(cx).parse_status(); + let snapshot_task = cx.spawn({ + let weak_buffer = buffer.downgrade(); + async move |_, cx| { + while *parse_status.borrow() != language::ParseStatus::Idle { + parse_status.changed().await?; + } + weak_buffer.read_with(cx, |buffer, _cx| buffer.snapshot()) + } + }); + + let parse_task = cx.background_spawn(async move { + let snapshot = snapshot_task.await?; + + anyhow::Ok( + declarations_in_buffer(&snapshot) + .into_iter() + .map(|item| { + ( + item.parent_index, + BufferDeclaration::from_outline(item, &snapshot), + ) + }) + .collect::>(), + ) + }); + + let task = cx.spawn({ + let weak_buffer = buffer.downgrade(); + async move |this, cx| { + let Ok(declarations) = parse_task.await else { + return; + }; + + this.update(cx, |this, _cx| { + let buffer_state = this + .buffers + .entry(weak_buffer.clone()) + .or_insert_with(Default::default); + + for old_declaration_id in &buffer_state.declarations { + let Some(declaration) = this.declarations.remove(*old_declaration_id) + else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = + this.identifiers.get_mut(declaration.identifier()) + { + identifier_declarations.remove(old_declaration_id); + } + } + + let mut new_ids = Vec::with_capacity(declarations.len()); + this.declarations.reserve(declarations.len()); + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = this.declarations.insert(Declaration::Buffer { + buffer: weak_buffer.clone(), + declaration, + }); + new_ids.push(declaration_id); + + this.identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); + } + + buffer_state.declarations = new_ids; + }) + .ok(); + } + }); + + self.buffers + .entry(buffer.downgrade()) + .or_insert_with(Default::default) + .task = Some(task); + } + + fn update_file( + &mut self, + entry_id: ProjectEntryId, + project_path: ProjectPath, + cx: &mut Context, + ) { + let Some(project) = self.project.upgrade() else { + return; + }; + let project = project.read(cx); + let Some(worktree) = project.worktree_for_id(project_path.worktree_id, cx) else { + return; + }; + let language_registry = project.languages().clone(); + + let snapshot_task = worktree.update(cx, |worktree, cx| { + let load_task = worktree.load_file(&project_path.path, cx); + cx.spawn(async move |_this, cx| { + let loaded_file = load_task.await?; + let language = language_registry + .language_for_file_path(&project_path.path) + .await + .log_err(); + + let buffer = cx.new(|cx| { + let mut buffer = Buffer::local(loaded_file.text, cx); + buffer.set_language(language, cx); + buffer + })?; + + let mut parse_status = buffer.read_with(cx, |buffer, _| buffer.parse_status())?; + while *parse_status.borrow() != language::ParseStatus::Idle { + parse_status.changed().await?; + } + + buffer.read_with(cx, |buffer, _cx| buffer.snapshot()) + }) + }); + + let parse_task = cx.background_spawn(async move { + let snapshot = snapshot_task.await?; + let declarations = declarations_in_buffer(&snapshot) + .into_iter() + .map(|item| { + ( + item.parent_index, + FileDeclaration::from_outline(item, &snapshot), + ) + }) + .collect::>(); + anyhow::Ok(declarations) + }); + + let task = cx.spawn({ + async move |this, cx| { + // TODO: how to handle errors? + let Ok(declarations) = parse_task.await else { + return; + }; + this.update(cx, |this, _cx| { + let file_state = this.files.entry(entry_id).or_insert_with(Default::default); + + for old_declaration_id in &file_state.declarations { + let Some(declaration) = this.declarations.remove(*old_declaration_id) + else { + debug_panic!("declaration not found"); + continue; + }; + if let Some(identifier_declarations) = + this.identifiers.get_mut(declaration.identifier()) + { + identifier_declarations.remove(old_declaration_id); + } + } + + let mut new_ids = Vec::with_capacity(declarations.len()); + this.declarations.reserve(declarations.len()); + + for (parent_index, mut declaration) in declarations { + declaration.parent = parent_index + .and_then(|ix| some_or_debug_panic(new_ids.get(ix).copied())); + + let identifier = declaration.identifier.clone(); + let declaration_id = this.declarations.insert(Declaration::File { + project_entry_id: entry_id, + declaration, + }); + new_ids.push(declaration_id); + + this.identifiers + .entry(identifier) + .or_default() + .insert(declaration_id); + } + + file_state.declarations = new_ids; + }) + .ok(); + } + }); + + self.files + .entry(entry_id) + .or_insert_with(Default::default) + .task = Some(task); + } +} + +impl BufferDeclaration { + pub fn from_outline(declaration: OutlineDeclaration, snapshot: &BufferSnapshot) -> Self { + // use of anchor_before is a guess that the proper behavior is to expand to include + // insertions immediately before the declaration, but not for insertions immediately after + Self { + parent: None, + identifier: declaration.identifier, + item_range: snapshot.anchor_before(declaration.item_range.start) + ..snapshot.anchor_before(declaration.item_range.end), + signature_range: snapshot.anchor_before(declaration.signature_range.start) + ..snapshot.anchor_before(declaration.signature_range.end), + } + } +} + +impl FileDeclaration { + pub fn from_outline( + declaration: OutlineDeclaration, + snapshot: &BufferSnapshot, + ) -> FileDeclaration { + FileDeclaration { + parent: None, + identifier: declaration.identifier, + item_range: declaration.item_range, + signature_text: snapshot + .text_for_range(declaration.signature_range.clone()) + .collect::() + .into(), + signature_range: declaration.signature_range, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::{path::Path, sync::Arc}; + + use futures::channel::oneshot; + use gpui::TestAppContext; + use indoc::indoc; + use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust}; + use project::{FakeFs, Project, ProjectItem}; + use serde_json::json; + use settings::SettingsStore; + use text::OffsetRangeExt as _; + use util::path; + + use crate::tree_sitter_index::TreeSitterIndex; + + #[gpui::test] + async fn test_unopen_indexed_files(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + let main = Identifier { + name: "main".into(), + language_id: rust_lang_id, + }; + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + assert_eq!(decls.len(), 2); + + let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + assert_eq!(decl.identifier, main.clone()); + assert_eq!(decl.item_range, 32..279); + + let decl = expect_file_decl("a.rs", &decls[1], &project, cx); + assert_eq!(decl.identifier, main); + assert_eq!(decl.item_range, 0..97); + }); + } + + #[gpui::test] + async fn test_parents_in_file(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + let test_process_data = Identifier { + name: "test_process_data".into(), + language_id: rust_lang_id, + }; + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + assert_eq!(decls.len(), 1); + + let decl = expect_file_decl("c.rs", &decls[0], &project, cx); + assert_eq!(decl.identifier, test_process_data); + + let parent_id = decl.parent.unwrap(); + let parent = index.declaration(parent_id).unwrap(); + let parent_decl = expect_file_decl("c.rs", &parent, &project, cx); + assert_eq!( + parent_decl.identifier, + Identifier { + name: "tests".into(), + language_id: rust_lang_id + } + ); + assert_eq!(parent_decl.parent, None); + }); + } + + #[gpui::test] + async fn test_parents_in_buffer(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + let test_process_data = Identifier { + name: "test_process_data".into(), + language_id: rust_lang_id, + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(test_process_data.clone(), cx); + assert_eq!(decls.len(), 1); + + let decl = expect_buffer_decl("c.rs", &decls[0], cx); + assert_eq!(decl.identifier, test_process_data); + + let parent_id = decl.parent.unwrap(); + let parent = index.declaration(parent_id).unwrap(); + let parent_decl = expect_buffer_decl("c.rs", &parent, cx); + assert_eq!( + parent_decl.identifier, + Identifier { + name: "tests".into(), + language_id: rust_lang_id + } + ); + assert_eq!(parent_decl.parent, None); + }); + + drop(buffer); + } + + #[gpui::test] + async fn test_declarations_limt(cx: &mut TestAppContext) { + let (_, index, rust_lang_id) = init_test(cx).await; + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<1>( + Identifier { + name: "main".into(), + language_id: rust_lang_id, + }, + cx, + ); + assert_eq!(decls.len(), 1); + }); + } + + #[gpui::test] + async fn test_buffer_shadow(cx: &mut TestAppContext) { + let (project, index, rust_lang_id) = init_test(cx).await; + + let main = Identifier { + name: "main".into(), + language_id: rust_lang_id, + }; + + let buffer = project + .update(cx, |project, cx| { + let project_path = project.find_project_path("c.rs", cx).unwrap(); + project.open_buffer(project_path, cx) + }) + .await + .unwrap(); + + cx.run_until_parked(); + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(main.clone(), cx); + assert_eq!(decls.len(), 2); + let decl = expect_buffer_decl("c.rs", &decls[0], cx); + assert_eq!(decl.identifier, main); + assert_eq!(decl.item_range.to_offset(&buffer.read(cx)), 32..279); + + expect_file_decl("a.rs", &decls[1], &project, cx); + }); + + // Drop the buffer and wait for release + let (release_tx, release_rx) = oneshot::channel(); + cx.update(|cx| { + cx.observe_release(&buffer, |_, _| { + release_tx.send(()).ok(); + }) + .detach(); + }); + drop(buffer); + cx.run_until_parked(); + release_rx.await.ok(); + cx.run_until_parked(); + + index.read_with(cx, |index, cx| { + let decls = index.declarations_for_identifier::<8>(main, cx); + assert_eq!(decls.len(), 2); + expect_file_decl("c.rs", &decls[0], &project, cx); + expect_file_decl("a.rs", &decls[1], &project, cx); + }); + } + + fn expect_buffer_decl<'a>( + path: &str, + declaration: &'a Declaration, + cx: &App, + ) -> &'a BufferDeclaration { + if let Declaration::Buffer { + declaration, + buffer, + } = declaration + { + assert_eq!( + buffer + .upgrade() + .unwrap() + .read(cx) + .project_path(cx) + .unwrap() + .path + .as_ref(), + Path::new(path), + ); + declaration + } else { + panic!("Expected a buffer declaration, found {:?}", declaration); + } + } + + fn expect_file_decl<'a>( + path: &str, + declaration: &'a Declaration, + project: &Entity, + cx: &App, + ) -> &'a FileDeclaration { + if let Declaration::File { + declaration, + project_entry_id: file, + } = declaration + { + assert_eq!( + project + .read(cx) + .path_for_entry(*file, cx) + .unwrap() + .path + .as_ref(), + Path::new(path), + ); + declaration + } else { + panic!("Expected a file declaration, found {:?}", declaration); + } + } + + async fn init_test( + cx: &mut TestAppContext, + ) -> (Entity, Entity, LanguageId) { + cx.update(|cx| { + let settings_store = SettingsStore::test(cx); + cx.set_global(settings_store); + language::init(cx); + Project::init_settings(cx); + }); + + let fs = FakeFs::new(cx.executor()); + fs.insert_tree( + path!("/root"), + json!({ + "a.rs": indoc! {r#" + fn main() { + let x = 1; + let y = 2; + let z = add(x, y); + println!("Result: {}", z); + } + + fn add(a: i32, b: i32) -> i32 { + a + b + } + "#}, + "b.rs": indoc! {" + pub struct Config { + pub name: String, + pub value: i32, + } + + impl Config { + pub fn new(name: String, value: i32) -> Self { + Config { name, value } + } + } + "}, + "c.rs": indoc! {r#" + use std::collections::HashMap; + + fn main() { + let args: Vec = std::env::args().collect(); + let data: Vec = args[1..] + .iter() + .filter_map(|s| s.parse().ok()) + .collect(); + let result = process_data(data); + println!("{:?}", result); + } + + fn process_data(data: Vec) -> HashMap { + let mut counts = HashMap::new(); + for value in data { + *counts.entry(value).or_insert(0) += 1; + } + counts + } + + #[cfg(test)] + mod tests { + use super::*; + + #[test] + fn test_process_data() { + let data = vec![1, 2, 2, 3]; + let result = process_data(data); + assert_eq!(result.get(&2), Some(&2)); + } + } + "#} + }), + ) + .await; + let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await; + let language_registry = project.read_with(cx, |project, _| project.languages().clone()); + let lang = rust_lang(); + let lang_id = lang.id(); + language_registry.add(Arc::new(lang)); + + let index = cx.new(|cx| TreeSitterIndex::new(&project, cx)); + cx.run_until_parked(); + + (project, index, lang_id) + } + + fn rust_lang() -> Language { + Language::new( + LanguageConfig { + name: "Rust".into(), + matcher: LanguageMatcher { + path_suffixes: vec!["rs".to_string()], + ..Default::default() + }, + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_outline_query(include_str!("../../languages/src/rust/outline.scm")) + .unwrap() + } +} diff --git a/crates/gpui/Cargo.toml b/crates/gpui/Cargo.toml index 44f819c135298dc991ad6036ad9948b5eaf609a4..ac1bdf85cb478064db42b3dccde8e44adee72fdd 100644 --- a/crates/gpui/Cargo.toml +++ b/crates/gpui/Cargo.toml @@ -115,7 +115,7 @@ seahash = "4.1" semantic_version.workspace = true serde.workspace = true serde_json.workspace = true -slotmap = "1.0.6" +slotmap.workspace = true smallvec.workspace = true smol.workspace = true stacksafe.workspace = true diff --git a/crates/language/src/buffer.rs b/crates/language/src/buffer.rs index 77270807644830a233cbd6f3c47e4912ff47f543..c94153ba00a29b40544b3c685ad4cfb1add3db1b 100644 --- a/crates/language/src/buffer.rs +++ b/crates/language/src/buffer.rs @@ -145,7 +145,7 @@ struct BufferBranchState { /// state of a buffer. pub struct BufferSnapshot { pub text: text::BufferSnapshot, - pub(crate) syntax: SyntaxSnapshot, + pub syntax: SyntaxSnapshot, file: Option>, diagnostics: SmallVec<[(LanguageServerId, DiagnosticSet); 2]>, remote_selections: TreeMap, @@ -660,7 +660,10 @@ impl HighlightedTextBuilder { syntax_snapshot: &'a SyntaxSnapshot, ) -> BufferChunks<'a> { let captures = syntax_snapshot.captures(range.clone(), snapshot, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let highlight_maps = captures @@ -3246,7 +3249,10 @@ impl BufferSnapshot { fn get_highlights(&self, range: Range) -> (SyntaxMapCaptures<'_>, Vec) { let captures = self.syntax.captures(range, &self.text, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let highlight_maps = captures .grammars() diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs index 77e8ee0232819a830e48d1d12a778ef19026d7b6..3c951e50ff231a72e284da743bb3e5d409eb9c5e 100644 --- a/crates/language/src/language.rs +++ b/crates/language/src/language.rs @@ -81,7 +81,9 @@ pub use language_registry::{ }; pub use lsp::{LanguageServerId, LanguageServerName}; pub use outline::*; -pub use syntax_map::{OwnedSyntaxLayer, SyntaxLayer, ToTreeSitterPoint, TreeSitterOptions}; +pub use syntax_map::{ + OwnedSyntaxLayer, SyntaxLayer, SyntaxMapMatches, ToTreeSitterPoint, TreeSitterOptions, +}; pub use text::{AnchorRangeExt, LineEnding}; pub use tree_sitter::{Node, Parser, Tree, TreeCursor}; @@ -1154,7 +1156,7 @@ pub struct Grammar { id: GrammarId, pub ts_language: tree_sitter::Language, pub(crate) error_query: Option, - pub(crate) highlights_query: Option, + pub highlights_config: Option, pub(crate) brackets_config: Option, pub(crate) redactions_config: Option, pub(crate) runnable_config: Option, @@ -1168,6 +1170,11 @@ pub struct Grammar { pub(crate) highlight_map: Mutex, } +pub struct HighlightsConfig { + pub query: Query, + pub identifier_capture_indices: Vec, +} + struct IndentConfig { query: Query, indent_capture_ix: u32, @@ -1332,7 +1339,7 @@ impl Language { grammar: ts_language.map(|ts_language| { Arc::new(Grammar { id: GrammarId::new(), - highlights_query: None, + highlights_config: None, brackets_config: None, outline_config: None, text_object_config: None, @@ -1430,7 +1437,29 @@ impl Language { pub fn with_highlights_query(mut self, source: &str) -> Result { let grammar = self.grammar_mut()?; - grammar.highlights_query = Some(Query::new(&grammar.ts_language, source)?); + let query = Query::new(&grammar.ts_language, source)?; + + let mut identifier_capture_indices = Vec::new(); + for name in [ + "variable", + "constant", + "constructor", + "function", + "function.method", + "function.method.call", + "function.special", + "property", + "type", + "type.interface", + ] { + identifier_capture_indices.extend(query.capture_index_for_name(name)); + } + + grammar.highlights_config = Some(HighlightsConfig { + query, + identifier_capture_indices, + }); + Ok(self) } @@ -1856,7 +1885,10 @@ impl Language { let tree = grammar.parse_text(text, None); let captures = SyntaxSnapshot::single_tree_captures(range.clone(), text, &tree, self, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let highlight_maps = vec![grammar.highlight_map()]; let mut offset = 0; @@ -1885,10 +1917,10 @@ impl Language { pub fn set_theme(&self, theme: &SyntaxTheme) { if let Some(grammar) = self.grammar.as_ref() - && let Some(highlights_query) = &grammar.highlights_query + && let Some(highlights_config) = &grammar.highlights_config { *grammar.highlight_map.lock() = - HighlightMap::new(highlights_query.capture_names(), theme); + HighlightMap::new(highlights_config.query.capture_names(), theme); } } @@ -2103,8 +2135,9 @@ impl Grammar { pub fn highlight_id_for_name(&self, name: &str) -> Option { let capture_id = self - .highlights_query + .highlights_config .as_ref()? + .query .capture_index_for_name(name)?; Some(self.highlight_map.lock().get(capture_id)) } diff --git a/crates/language/src/syntax_map/syntax_map_tests.rs b/crates/language/src/syntax_map/syntax_map_tests.rs index 622731b7814ce16bfcc026b6723e80d5ba4dda7a..6b19d651e241ad71229c6c7fc429883a44367304 100644 --- a/crates/language/src/syntax_map/syntax_map_tests.rs +++ b/crates/language/src/syntax_map/syntax_map_tests.rs @@ -1409,12 +1409,15 @@ fn assert_capture_ranges( ) { let mut actual_ranges = Vec::>::new(); let captures = syntax_map.captures(0..buffer.len(), buffer, |grammar| { - grammar.highlights_query.as_ref() + grammar + .highlights_config + .as_ref() + .map(|config| &config.query) }); let queries = captures .grammars() .iter() - .map(|grammar| grammar.highlights_query.as_ref().unwrap()) + .map(|grammar| &grammar.highlights_config.as_ref().unwrap().query) .collect::>(); for capture in captures { let name = &queries[capture.grammar_index].capture_names()[capture.index as usize]; diff --git a/crates/util/src/util.rs b/crates/util/src/util.rs index 90f5be1c92875ac0b9b2d3e7352ae858371b3686..ee18093784e4b0f2b78db9240d918685b6f01b6f 100644 --- a/crates/util/src/util.rs +++ b/crates/util/src/util.rs @@ -1095,6 +1095,15 @@ impl From> for ConnectionResult { } } +#[track_caller] +pub fn some_or_debug_panic(option: Option) -> Option { + #[cfg(debug_assertions)] + if option.is_none() { + panic!("Unexpected None"); + } + option +} + #[cfg(test)] mod tests { use super::*;