diff --git a/Cargo.lock b/Cargo.lock index 8aba19b5c0ee2777fb0809956712bbaf74997c5d..57e4cad919c0b6f37abe2e5a7b49e973af3fcd9c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5189,6 +5189,9 @@ dependencies = [ "strum 0.27.1", "text", "tree-sitter", + "tree-sitter-c", + "tree-sitter-cpp", + "tree-sitter-go", "workspace-hack", "zed-collections", "zed-util", @@ -16964,8 +16967,7 @@ dependencies = [ [[package]] name = "tree-sitter-typescript" version = "0.23.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c5f76ed8d947a75cc446d5fccd8b602ebf0cde64ccf2ffa434d873d7a575eff" +source = "git+https://github.com/zed-industries/tree-sitter-typescript?rev=e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899#e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" dependencies = [ "cc", "tree-sitter-language", @@ -20785,6 +20787,7 @@ dependencies = [ "terminal_view", "watch", "workspace-hack", + "zed-collections", "zed-util", "zeta", "zeta2", diff --git a/Cargo.toml b/Cargo.toml index 87f912c6be8df1a5d93e6622b041c58d8f66e75f..34214772f5f33f9b3521e6d1ae2744857cf1c2fa 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -693,7 +693,7 @@ tree-sitter-python = "0.25" tree-sitter-regex = "0.24" tree-sitter-ruby = "0.23" tree-sitter-rust = "0.24" -tree-sitter-typescript = "0.23" +tree-sitter-typescript = { git = "https://github.com/zed-industries/tree-sitter-typescript", rev = "e2c53597d6a5d9cf7bbe8dccde576fe1e46c5899" } # https://github.com/tree-sitter/tree-sitter-typescript/pull/347 tree-sitter-yaml = { git = "https://github.com/zed-industries/tree-sitter-yaml", rev = "baff0b51c64ef6a1fb1f8390f3ad6015b83ec13a" } unicase = "2.6" unicode-script = "0.5.7" diff --git a/crates/cloud_llm_client/src/predict_edits_v3.rs b/crates/cloud_llm_client/src/predict_edits_v3.rs index 90df92f54216c9040c3a36b737bcf9415901ee87..ce53de99efbe801e3bf2fc37f9acc423d6737d1e 100644 --- a/crates/cloud_llm_client/src/predict_edits_v3.rs +++ b/crates/cloud_llm_client/src/predict_edits_v3.rs @@ -127,7 +127,6 @@ pub struct DeclarationScoreComponents { pub declaration_count: usize, pub reference_line_distance: u32, pub declaration_line_distance: u32, - pub declaration_line_distance_rank: usize, pub excerpt_vs_item_jaccard: f32, pub excerpt_vs_signature_jaccard: f32, pub adjacent_vs_item_jaccard: f32, @@ -136,6 +135,13 @@ pub struct DeclarationScoreComponents { pub excerpt_vs_signature_weighted_overlap: f32, pub adjacent_vs_item_weighted_overlap: f32, pub adjacent_vs_signature_weighted_overlap: f32, + pub path_import_match_count: usize, + pub wildcard_path_import_match_count: usize, + pub import_similarity: f32, + pub max_import_similarity: f32, + pub normalized_import_similarity: f32, + pub wildcard_import_similarity: f32, + pub normalized_wildcard_import_similarity: f32, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/edit_prediction_context/Cargo.toml b/crates/edit_prediction_context/Cargo.toml index c34386b3fb77565e627887155055917ed6ceb40c..c2e80d4a1a4fcc04dcc05c81369a9d9a2155954e 100644 --- a/crates/edit_prediction_context/Cargo.toml +++ b/crates/edit_prediction_context/Cargo.toml @@ -19,6 +19,7 @@ collections.workspace = true futures.workspace = true gpui.workspace = true hashbrown.workspace = true +indoc.workspace = true itertools.workspace = true language.workspace = true log.workspace = true @@ -45,5 +46,8 @@ project = {workspace= true, features = ["test-support"]} serde_json.workspace = true settings = {workspace= true, features = ["test-support"]} text = { workspace = true, features = ["test-support"] } 
+tree-sitter-c.workspace = true +tree-sitter-cpp.workspace = true +tree-sitter-go.workspace = true util = { workspace = true, features = ["test-support"] } zlog.workspace = true diff --git a/crates/edit_prediction_context/src/declaration.rs b/crates/edit_prediction_context/src/declaration.rs index a6efe63fc606580311d6e7653bb5ee98a80fb9d3..b57054cb537655184d4a52b511213dcfa570cd87 100644 --- a/crates/edit_prediction_context/src/declaration.rs +++ b/crates/edit_prediction_context/src/declaration.rs @@ -1,9 +1,11 @@ -use language::LanguageId; +use language::{Language, LanguageId}; use project::ProjectEntryId; -use std::borrow::Cow; use std::ops::Range; use std::sync::Arc; +use std::{borrow::Cow, path::Path}; use text::{Bias, BufferId, Rope}; +use util::paths::{path_ends_with, strip_path_suffix}; +use util::rel_path::RelPath; use crate::outline::OutlineDeclaration; @@ -22,12 +24,14 @@ pub enum Declaration { File { project_entry_id: ProjectEntryId, declaration: FileDeclaration, + cached_path: CachedDeclarationPath, }, Buffer { project_entry_id: ProjectEntryId, buffer_id: BufferId, rope: Rope, declaration: BufferDeclaration, + cached_path: CachedDeclarationPath, }, } @@ -73,6 +77,13 @@ impl Declaration { } } + pub fn cached_path(&self) -> &CachedDeclarationPath { + match self { + Declaration::File { cached_path, .. } => cached_path, + Declaration::Buffer { cached_path, .. } => cached_path, + } + } + pub fn item_range(&self) -> Range<usize> { match self { Declaration::File { declaration, .. } => declaration.item_range.clone(), @@ -235,3 +246,69 @@ impl BufferDeclaration { } } } + +#[derive(Debug, Clone)] +pub struct CachedDeclarationPath { + pub worktree_abs_path: Arc<Path>, + pub rel_path: Arc<RelPath>, + /// The relative path of the file, possibly stripped according to `import_path_strip_regex`.
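+ /// Illustrative example (not part of this diff): with TypeScript's `import_path_strip_regex` of + /// `(?:/index)?\.[jt]s$`, a file at `src/maths/index.ts` is cached here as `src/maths`, which is + /// the form that an import path such as `./maths` is matched against.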
+ pub rel_path_after_regex_stripping: Arc<RelPath>, +} + +impl CachedDeclarationPath { + pub fn new( + worktree_abs_path: Arc<Path>, + path: &Arc<RelPath>, + language: Option<&Arc<Language>>, + ) -> Self { + let rel_path = path.clone(); + let rel_path_after_regex_stripping = if let Some(language) = language + && let Some(strip_regex) = language.config().import_path_strip_regex.as_ref() + && let Ok(stripped) = RelPath::unix(&Path::new( + strip_regex.replace_all(rel_path.as_unix_str(), "").as_ref(), + )) { + Arc::from(stripped) + } else { + rel_path.clone() + }; + CachedDeclarationPath { + worktree_abs_path, + rel_path, + rel_path_after_regex_stripping, + } + } + + #[cfg(test)] + pub fn new_for_test(worktree_abs_path: &str, rel_path: &str) -> Self { + let rel_path: Arc<RelPath> = util::rel_path::rel_path(rel_path).into(); + CachedDeclarationPath { + worktree_abs_path: std::path::PathBuf::from(worktree_abs_path).into(), + rel_path_after_regex_stripping: rel_path.clone(), + rel_path, + } + } + + pub fn ends_with_posix_path(&self, path: &Path) -> bool { + if path.as_os_str().len() <= self.rel_path_after_regex_stripping.as_unix_str().len() { + path_ends_with(self.rel_path_after_regex_stripping.as_std_path(), path) + } else { + if let Some(remaining) = + strip_path_suffix(path, self.rel_path_after_regex_stripping.as_std_path()) + { + path_ends_with(&self.worktree_abs_path, remaining) + } else { + false + } + } + } + + pub fn equals_absolute_path(&self, path: &Path) -> bool { + if let Some(remaining) = + strip_path_suffix(path, &self.rel_path_after_regex_stripping.as_std_path()) + { + self.worktree_abs_path.as_ref() == remaining + } else { + false + } + } +} diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index 6f027ed1f63cdd2688cd149edcc19f7a8fbc704f..6531fa746240df9466c35b246c804a2827e3607f 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -1,15 +1,15 @@ use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents; use collections::HashMap; -use itertools::Itertools as _; use language::BufferSnapshot; use ordered_float::OrderedFloat; use serde::Serialize; -use std::{cmp::Reverse, ops::Range}; +use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc}; use strum::EnumIter; use text::{Point, ToPoint}; use crate::{ - Declaration, EditPredictionExcerpt, Identifier, + CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier, + imports::{Import, Imports, Module}, reference::{Reference, ReferenceRegion}, syntax_index::SyntaxIndexState, text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient}, @@ -17,12 +17,17 @@ use crate::{ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct EditPredictionScoreOptions { + pub omit_excerpt_overlaps: bool, +} + #[derive(Clone, Debug)] pub struct ScoredDeclaration { + /// identifier used by the local reference pub identifier: Identifier, pub declaration: Declaration, - pub score_components: DeclarationScoreComponents, - pub scores: DeclarationScores, + pub components: DeclarationScoreComponents, } #[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)] @@ -31,12 +36,55 @@ pub enum DeclarationStyle { Declaration, } +#[derive(Clone, Debug, Serialize, Default)] +pub struct DeclarationScores { + pub signature: f32, + pub declaration: f32, + pub retrieval: f32, +} + impl ScoredDeclaration { /// Returns the score for this declaration with the specified
style. pub fn score(&self, style: DeclarationStyle) -> f32 { + // TODO: handle truncation + + // Score related to how likely this is the correct declaration, range 0 to 10 + let retrieval = self.retrieval_score(); + + // Score related to the distance between the reference and cursor, range 0 to 1 + let distance_score = if self.components.is_referenced_nearby { + 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0) + } else { + // same score as ~4 lines away; the rationale is to not overly penalize references from parent signatures + 0.5 + }; + + // For now, instead of a linear combination, the scores are just multiplied together. + let combined_score = 10.0 * retrieval * distance_score; + match style { - DeclarationStyle::Signature => self.scores.signature, - DeclarationStyle::Declaration => self.scores.declaration, + DeclarationStyle::Signature => { + combined_score * self.components.excerpt_vs_signature_weighted_overlap + } + DeclarationStyle::Declaration => { + 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap + } + } + } + + pub fn retrieval_score(&self) -> f32 { + if self.components.is_same_file { + 10.0 / self.components.same_file_declaration_count as f32 + } else if self.components.path_import_match_count > 0 { + 3.0 + } else if self.components.wildcard_path_import_match_count > 0 { + 1.0 + } else if self.components.normalized_import_similarity > 0.0 { + self.components.normalized_import_similarity + } else if self.components.normalized_wildcard_import_similarity > 0.0 { + 0.5 * self.components.normalized_wildcard_import_similarity + } else { + 1.0 / self.components.declaration_count as f32 } } @@ -54,100 +102,215 @@ impl ScoredDeclaration { } pub fn score_density(&self, style: DeclarationStyle) -> f32 { - self.score(style) / (self.size(style)) as f32 + self.score(style) / self.size(style) as f32 } } pub fn scored_declarations( + options: &EditPredictionScoreOptions, index: &SyntaxIndexState, excerpt: &EditPredictionExcerpt, excerpt_occurrences: &Occurrences, adjacent_occurrences: &Occurrences, + imports: &Imports, identifier_to_references: HashMap<Identifier, Vec<Reference>>, cursor_offset: usize, current_buffer: &BufferSnapshot, ) -> Vec<ScoredDeclaration> { let cursor_point = cursor_offset.to_point(&current_buffer); + let mut wildcard_import_occurrences = Vec::new(); + let mut wildcard_import_paths = Vec::new(); + for wildcard_import in imports.wildcard_modules.iter() { + match wildcard_import { + Module::Namespace(namespace) => { + wildcard_import_occurrences.push(namespace.occurrences()) + } + Module::SourceExact(path) => wildcard_import_paths.push(path), + Module::SourceFuzzy(path) => { + wildcard_import_occurrences.push(Occurrences::from_path(&path)) + } + } + } + let mut declarations = identifier_to_references .into_iter() .flat_map(|(identifier, references)| { - let declarations = - index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier); + let mut import_occurrences = Vec::new(); + let mut import_paths = Vec::new(); + let mut found_external_identifier: Option<&Identifier> = None; + + if let Some(imports) = imports.identifier_to_imports.get(&identifier) { + // Only use the alias when it's the only import; could be generalized if some language + // has overlapping aliases + // + // TODO: when an aliased declaration is included in the prompt, should include the + // aliasing in the prompt. + // + // TODO: For SourceFuzzy consider having componentwise comparison that pays + // attention to ordering.
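+ // Illustrative example: for `use std::io::Result as IoResult;` (see `test_rust_alias`), a + // reference to `IoResult` carries a single `Import::Alias { module: Namespace(["std", "io"]), + // external_identifier: "Result" }`, so the index lookup below searches for `Result` rather + // than `IoResult`.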
+ if let [ + Import::Alias { + module, + external_identifier, + }, + ] = imports.as_slice() + { + match module { + Module::Namespace(namespace) => { + import_occurrences.push(namespace.occurrences()) + } + Module::SourceExact(path) => import_paths.push(path), + Module::SourceFuzzy(path) => { + import_occurrences.push(Occurrences::from_path(&path)) + } + } + found_external_identifier = Some(&external_identifier); + } else { + for import in imports { + match import { + Import::Direct { module } => match module { + Module::Namespace(namespace) => { + import_occurrences.push(namespace.occurrences()) + } + Module::SourceExact(path) => import_paths.push(path), + Module::SourceFuzzy(path) => { + import_occurrences.push(Occurrences::from_path(&path)) + } + }, + Import::Alias { .. } => {} + } + } + } + } + + let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier); + // TODO: update this to be able to return more declarations? Especially if there is the + // ability to quickly filter a large list (based on imports) + let declarations = index + .declarations_for_identifier::( + &identifier_to_lookup, + ); let declaration_count = declarations.len(); - declarations - .into_iter() - .filter_map(|(declaration_id, declaration)| match declaration { + if declaration_count == 0 { + return Vec::new(); + } + + // TODO: option to filter out other candidates when same file / import match + let mut checked_declarations = Vec::new(); + for (declaration_id, declaration) in declarations { + match declaration { Declaration::Buffer { buffer_id, declaration: buffer_declaration, .. } => { - let is_same_file = buffer_id == ¤t_buffer.remote_id(); - - if is_same_file { - let overlaps_excerpt = + if buffer_id == ¤t_buffer.remote_id() { + let already_included_in_prompt = range_intersection(&buffer_declaration.item_range, &excerpt.range) - .is_some(); - if overlaps_excerpt - || excerpt - .parent_declarations - .iter() - .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id) - { - None - } else { + .is_some() + || excerpt.parent_declarations.iter().any( + |(excerpt_parent, _)| excerpt_parent == &declaration_id, + ); + if !options.omit_excerpt_overlaps || !already_included_in_prompt { let declaration_line = buffer_declaration .item_range .start .to_point(current_buffer) .row; - Some(( - true, - (cursor_point.row as i32 - declaration_line as i32) - .unsigned_abs(), + let declaration_line_distance = (cursor_point.row as i32 + - declaration_line as i32) + .unsigned_abs(); + checked_declarations.push(CheckedDeclaration { declaration, - )) + same_file_line_distance: Some(declaration_line_distance), + path_import_match_count: 0, + wildcard_path_import_match_count: 0, + }); } + continue; } else { - Some((false, u32::MAX, declaration)) } } - Declaration::File { .. } => { - // We can assume that a file declaration is in a different file, - // because the current one must be open - Some((false, u32::MAX, declaration)) + Declaration::File { .. 
} => {} + } + let declaration_path = declaration.cached_path(); + let path_import_match_count = import_paths + .iter() + .filter(|import_path| { + declaration_path_matches_import(&declaration_path, import_path) + }) + .count(); + let wildcard_path_import_match_count = wildcard_import_paths + .iter() + .filter(|import_path| { + declaration_path_matches_import(&declaration_path, import_path) + }) + .count(); + checked_declarations.push(CheckedDeclaration { + declaration, + same_file_line_distance: None, + path_import_match_count, + wildcard_path_import_match_count, + }); + } + + let mut max_import_similarity = 0.0; + let mut max_wildcard_import_similarity = 0.0; + + let mut scored_declarations_for_identifier = checked_declarations + .into_iter() + .map(|checked_declaration| { + let same_file_declaration_count = + index.file_declaration_count(checked_declaration.declaration); + + let declaration = score_declaration( + &identifier, + &references, + checked_declaration, + same_file_declaration_count, + declaration_count, + &excerpt_occurrences, + &adjacent_occurrences, + &import_occurrences, + &wildcard_import_occurrences, + cursor_point, + current_buffer, + ); + + if declaration.components.import_similarity > max_import_similarity { + max_import_similarity = declaration.components.import_similarity; + } + + if declaration.components.wildcard_import_similarity + > max_wildcard_import_similarity + { + max_wildcard_import_similarity = + declaration.components.wildcard_import_similarity; } + + declaration }) - .sorted_by_key(|&(_, distance, _)| distance) - .enumerate() - .map( - |( - declaration_line_distance_rank, - (is_same_file, declaration_line_distance, declaration), - )| { - let same_file_declaration_count = index.file_declaration_count(declaration); - - score_declaration( - &identifier, - &references, - declaration.clone(), - is_same_file, - declaration_line_distance, - declaration_line_distance_rank, - same_file_declaration_count, - declaration_count, - &excerpt_occurrences, - &adjacent_occurrences, - cursor_point, - current_buffer, - ) - }, - ) - .collect::>() + .collect::>(); + + if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 { + for declaration in scored_declarations_for_identifier.iter_mut() { + if max_import_similarity > 0.0 { + declaration.components.max_import_similarity = max_import_similarity; + declaration.components.normalized_import_similarity = + declaration.components.import_similarity / max_import_similarity; + } + if max_wildcard_import_similarity > 0.0 { + declaration.components.normalized_wildcard_import_similarity = + declaration.components.wildcard_import_similarity + / max_wildcard_import_similarity; + } + } + } + + scored_declarations_for_identifier }) - .flatten() .collect::>(); declarations.sort_unstable_by_key(|declaration| { @@ -160,6 +323,24 @@ pub fn scored_declarations( declarations } +struct CheckedDeclaration<'a> { + declaration: &'a Declaration, + same_file_line_distance: Option, + path_import_match_count: usize, + wildcard_path_import_match_count: usize, +} + +fn declaration_path_matches_import( + declaration_path: &CachedDeclarationPath, + import_path: &Arc, +) -> bool { + if import_path.is_absolute() { + declaration_path.equals_absolute_path(import_path) + } else { + declaration_path.ends_with_posix_path(import_path) + } +} + fn range_intersection(a: &Range, b: &Range) -> Option> { let start = a.start.clone().max(b.start.clone()); let end = a.end.clone().min(b.end.clone()); @@ -173,17 +354,23 @@ fn range_intersection(a: &Range, b: 
&Range) -> Option Option { +) -> ScoredDeclaration { + let CheckedDeclaration { + declaration, + same_file_line_distance, + path_import_match_count, + wildcard_path_import_match_count, + } = checked_declaration; + let is_referenced_nearby = references .iter() .any(|r| r.region == ReferenceRegion::Nearby); @@ -200,6 +387,9 @@ fn score_declaration( .min() .unwrap(); + let is_same_file = same_file_line_distance.is_some(); + let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX); + let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0); let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0); let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences); @@ -219,6 +409,37 @@ fn score_declaration( let adjacent_vs_signature_weighted_overlap = weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences); + let mut import_similarity = 0f32; + let mut wildcard_import_similarity = 0f32; + if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() { + let cached_path = declaration.cached_path(); + let path_occurrences = Occurrences::from_worktree_path( + cached_path + .worktree_abs_path + .file_name() + .map(|f| f.to_string_lossy()), + &cached_path.rel_path, + ); + import_similarity = import_occurrences + .iter() + .map(|namespace_occurrences| { + OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences)) + }) + .max() + .map(|similarity| similarity.into_inner()) + .unwrap_or_default(); + + // TODO: Consider something other than max + wildcard_import_similarity = wildcard_import_occurrences + .iter() + .map(|namespace_occurrences| { + OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences)) + }) + .max() + .map(|similarity| similarity.into_inner()) + .unwrap_or_default(); + } + // TODO: Consider adding declaration_file_count let score_components = DeclarationScoreComponents { is_same_file, @@ -226,7 +447,6 @@ fn score_declaration( is_referenced_in_breadcrumb, reference_line_distance, declaration_line_distance, - declaration_line_distance_rank, reference_count, same_file_declaration_count, declaration_count, @@ -238,52 +458,59 @@ fn score_declaration( excerpt_vs_signature_weighted_overlap, adjacent_vs_item_weighted_overlap, adjacent_vs_signature_weighted_overlap, + path_import_match_count, + wildcard_path_import_match_count, + import_similarity, + max_import_similarity: 0.0, + normalized_import_similarity: 0.0, + wildcard_import_similarity, + normalized_wildcard_import_similarity: 0.0, }; - Some(ScoredDeclaration { + ScoredDeclaration { identifier: identifier.clone(), - declaration: declaration, - scores: DeclarationScores::score(&score_components), - score_components, - }) + declaration: declaration.clone(), + components: score_components, + } } -#[derive(Clone, Debug, Serialize)] -pub struct DeclarationScores { - pub signature: f32, - pub declaration: f32, - pub retrieval: f32, -} +#[cfg(test)] +mod test { + use super::*; -impl DeclarationScores { - fn score(components: &DeclarationScoreComponents) -> DeclarationScores { - // TODO: handle truncation + #[test] + fn test_declaration_path_matches() { + let declaration_path = + CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts"); - // Score related to how likely this is the correct declaration, range 0 to 1 - let retrieval = if components.is_same_file { - // TODO: use declaration_line_distance_rank - 1.0 / 
components.same_file_declaration_count as f32 - } else { - 1.0 / components.declaration_count as f32 - }; + assert!(declaration_path_matches_import( + &declaration_path, + &Path::new("maths.ts").into() + )); - // Score related to the distance between the reference and cursor, range 0 to 1 - let distance_score = if components.is_referenced_nearby { - 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0) - } else { - // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures - 0.5 - }; + assert!(declaration_path_matches_import( + &declaration_path, + &Path::new("project/src/maths.ts").into() + )); - // For now instead of linear combination, the scores are just multiplied together. - let combined_score = 10.0 * retrieval * distance_score; + assert!(declaration_path_matches_import( + &declaration_path, + &Path::new("user/project/src/maths.ts").into() + )); - DeclarationScores { - signature: combined_score * components.excerpt_vs_signature_weighted_overlap, - // declaration score gets boosted both by being multiplied by 2 and by there being more - // weighted overlap. - declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap, - retrieval, - } + assert!(declaration_path_matches_import( + &declaration_path, + &Path::new("/home/user/project/src/maths.ts").into() + )); + + assert!(!declaration_path_matches_import( + &declaration_path, + &Path::new("other.ts").into() + )); + + assert!(!declaration_path_matches_import( + &declaration_path, + &Path::new("/home/user/project/src/other.ts").into() + )); } } diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index c994caf7546fdb22539e9d60ff976d4379ed2cc8..19cafe0412bb0db67ef906d1ff119d7c23234f78 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -1,12 +1,13 @@ mod declaration; mod declaration_scoring; mod excerpt; +mod imports; mod outline; mod reference; mod syntax_index; pub mod text_similarity; -use std::sync::Arc; +use std::{path::Path, sync::Arc}; use collections::HashMap; use gpui::{App, AppContext as _, Entity, Task}; @@ -16,9 +17,17 @@ use text::{Point, ToOffset as _}; pub use declaration::*; pub use declaration_scoring::*; pub use excerpt::*; +pub use imports::*; pub use reference::*; pub use syntax_index::*; +#[derive(Clone, Debug, PartialEq)] +pub struct EditPredictionContextOptions { + pub use_imports: bool, + pub excerpt: EditPredictionExcerptOptions, + pub score: EditPredictionScoreOptions, +} + #[derive(Clone, Debug)] pub struct EditPredictionContext { pub excerpt: EditPredictionExcerpt, @@ -31,21 +40,34 @@ impl EditPredictionContext { pub fn gather_context_in_background( cursor_point: Point, buffer: BufferSnapshot, - excerpt_options: EditPredictionExcerptOptions, + options: EditPredictionContextOptions, syntax_index: Option>, cx: &mut App, ) -> Task> { + let parent_abs_path = project::File::from_dyn(buffer.file()).and_then(|f| { + let mut path = f.worktree.read(cx).absolutize(&f.path); + if path.pop() { Some(path) } else { None } + }); + if let Some(syntax_index) = syntax_index { let index_state = syntax_index.read_with(cx, |index, _cx| Arc::downgrade(index.state())); cx.background_spawn(async move { + let parent_abs_path = parent_abs_path.as_deref(); let index_state = index_state.upgrade()?; let index_state = index_state.lock().await; - Self::gather_context(cursor_point, 
&buffer, &excerpt_options, Some(&index_state)) + Self::gather_context( + cursor_point, + &buffer, + parent_abs_path, + &options, + Some(&index_state), + ) }) } else { cx.background_spawn(async move { - Self::gather_context(cursor_point, &buffer, &excerpt_options, None) + let parent_abs_path = parent_abs_path.as_deref(); + Self::gather_context(cursor_point, &buffer, parent_abs_path, &options, None) }) } } @@ -53,13 +75,20 @@ impl EditPredictionContext { pub fn gather_context( cursor_point: Point, buffer: &BufferSnapshot, - excerpt_options: &EditPredictionExcerptOptions, + parent_abs_path: Option<&Path>, + options: &EditPredictionContextOptions, index_state: Option<&SyntaxIndexState>, ) -> Option<Self> { + let imports = if options.use_imports { + Imports::gather(&buffer, parent_abs_path) + } else { + Imports::default() + }; Self::gather_context_with_references_fn( cursor_point, buffer, - excerpt_options, + &imports, + options, index_state, references_in_excerpt, ) @@ -68,7 +97,8 @@ impl EditPredictionContext { pub fn gather_context_with_references_fn( cursor_point: Point, buffer: &BufferSnapshot, - excerpt_options: &EditPredictionExcerptOptions, + imports: &Imports, + options: &EditPredictionContextOptions, index_state: Option<&SyntaxIndexState>, get_references: impl FnOnce( &EditPredictionExcerpt, @@ -79,7 +109,7 @@ let excerpt = EditPredictionExcerpt::select_from_buffer( cursor_point, buffer, - excerpt_options, + &options.excerpt, index_state, )?; let excerpt_text = excerpt.text(buffer); @@ -101,10 +131,12 @@ let references = get_references(&excerpt, &excerpt_text, buffer); scored_declarations( + &options.score, &index_state, &excerpt, &excerpt_occurrences, &adjacent_occurrences, + &imports, references, cursor_offset_in_file, buffer, @@ -160,12 +192,18 @@ EditPredictionContext::gather_context_in_background( cursor_point, buffer_snapshot, - EditPredictionExcerptOptions { - max_bytes: 60, - min_bytes: 10, - target_before_cursor_over_total_bytes: 0.5, + EditPredictionContextOptions { + use_imports: true, + excerpt: EditPredictionExcerptOptions { + max_bytes: 60, + min_bytes: 10, + target_before_cursor_over_total_bytes: 0.5, + }, + score: EditPredictionScoreOptions { + omit_excerpt_overlaps: true, + }, }, - Some(index), + Some(index.clone()), cx, ) }) diff --git a/crates/edit_prediction_context/src/imports.rs b/crates/edit_prediction_context/src/imports.rs new file mode 100644 index 0000000000000000000000000000000000000000..70f175159340ddb9a6f26f23db0c1b3c843e7b96 --- /dev/null +++ b/crates/edit_prediction_context/src/imports.rs @@ -0,0 +1,1319 @@ +use collections::HashMap; +use language::BufferSnapshot; +use language::ImportsConfig; +use language::Language; +use std::ops::Deref; +use std::path::Path; +use std::sync::Arc; +use std::{borrow::Cow, ops::Range}; +use text::OffsetRangeExt as _; +use util::RangeExt; +use util::paths::PathStyle; + +use crate::Identifier; +use crate::text_similarity::Occurrences; + +// TODO: Write documentation for extension authors. The @import capture must match before, or in the +// same pattern as, all captures it contains + +// Future improvements to consider: +// +// * Distinguish absolute vs relative paths in captures. `#include "maths.h"` is relative whereas +// `#include <maths.h>` is not. +// +// * Provide the name used when importing whole modules (see tests with "named_module" in the name). +// To be useful, will require parsing of identifier qualification.
+// +// * Scoping for imports that aren't at the top level +// +// * Only scan a prefix of the file, when possible. This could look like having query matches that +// indicate it reached a declaration that is not allowed in the import section. +// +// * Support directly parsing to occurrences instead of storing namespaces / paths. Types should be +// generic on this, so that tests etc. can still use strings. Could do similar in syntax index. +// +// * Distinguish different types of namespaces when known. E.g. "name.type" capture. Once capture +// names are more open-ended like this, it may make sense to build and cache a jump table (direct +// dispatch from capture index). +// +// * There are a few "Language specific:" comments on behavior that gets applied to all languages. +// Would be cleaner to be conditional on the language or otherwise configured. + +#[derive(Debug, Clone, Default)] +pub struct Imports { + pub identifier_to_imports: HashMap<Identifier, Vec<Import>>, + pub wildcard_modules: Vec<Module>, +} + +#[derive(Debug, Clone)] +pub enum Import { + Direct { + module: Module, + }, + Alias { + module: Module, + external_identifier: Identifier, + }, +} + +#[derive(Debug, Clone)] +pub enum Module { + SourceExact(Arc<Path>), + SourceFuzzy(Arc<Path>), + Namespace(Namespace), +} + +impl Module { + fn empty() -> Self { + Module::Namespace(Namespace::default()) + } + + fn push_range( + &mut self, + range: &ModuleRange, + snapshot: &BufferSnapshot, + language: &Language, + parent_abs_path: Option<&Path>, + ) -> usize { + if range.is_empty() { + return 0; + } + + match range { + ModuleRange::Source(range) => { + if let Self::Namespace(namespace) = self + && namespace.0.is_empty() + { + let path = snapshot.text_for_range(range.clone()).collect::<Cow<str>>(); + + let path = if let Some(strip_regex) = + language.config().import_path_strip_regex.as_ref() + { + strip_regex.replace_all(&path, "") + } else { + path + }; + + let path = Path::new(path.as_ref()); + if (path.starts_with(".") || path.starts_with("..")) + && let Some(parent_abs_path) = parent_abs_path + && let Ok(abs_path) = + util::paths::normalize_lexically(&parent_abs_path.join(path)) + { + *self = Self::SourceExact(abs_path.into()); + } else { + *self = Self::SourceFuzzy(path.into()); + }; + } else if matches!(self, Self::SourceExact(_)) + || matches!(self, Self::SourceFuzzy(_)) + { + log::warn!("bug in imports query: encountered multiple @source matches"); + } else { + log::warn!( + "bug in imports query: encountered both @namespace and @source match" + ); + } + } + ModuleRange::Namespace(range) => { + if let Self::Namespace(namespace) = self { + let segment = range_text(snapshot, range); + if language.config().ignored_import_segments.contains(&segment) { + return 0; + } else { + namespace.0.push(segment); + return 1; + } + } else { + log::warn!( + "bug in imports query: encountered both @namespace and @source match" + ); + } + } + } + 0 + } +} + +#[derive(Debug, Clone)] +enum ModuleRange { + Source(Range<usize>), + Namespace(Range<usize>), +} + +impl Deref for ModuleRange { + type Target = Range<usize>; + + fn deref(&self) -> &Self::Target { + match self { + ModuleRange::Source(range) => range, + ModuleRange::Namespace(range) => range, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub struct Namespace(pub Vec<Arc<str>>); + +impl Namespace { + pub fn occurrences(&self) -> Occurrences { + Occurrences::from_identifiers(&self.0) + } +} + +impl Imports { + pub fn gather(snapshot: &BufferSnapshot, parent_abs_path: Option<&Path>) -> Self { + // Query to match different import patterns + let mut matches =
snapshot + .syntax + .matches(0..snapshot.len(), &snapshot.text, |grammar| { + grammar.imports_config().map(|imports| &imports.query) + }); + + let mut detached_nodes: Vec = Vec::new(); + let mut identifier_to_imports = HashMap::default(); + let mut wildcard_modules = Vec::new(); + let mut import_range = None; + + while let Some(query_match) = matches.peek() { + let ImportsConfig { + query: _, + import_ix, + name_ix, + namespace_ix, + source_ix, + list_ix, + wildcard_ix, + alias_ix, + } = matches.grammars()[query_match.grammar_index] + .imports_config() + .unwrap(); + + let mut new_import_range = None; + let mut alias_range = None; + let mut modules = Vec::new(); + let mut content: Option<(Range, ContentKind)> = None; + for capture in query_match.captures { + let capture_range = capture.node.byte_range(); + + if capture.index == *import_ix { + new_import_range = Some(capture_range); + } else if Some(capture.index) == *namespace_ix { + modules.push(ModuleRange::Namespace(capture_range)); + } else if Some(capture.index) == *source_ix { + modules.push(ModuleRange::Source(capture_range)); + } else if Some(capture.index) == *alias_ix { + alias_range = Some(capture_range); + } else { + let mut found_content = None; + if Some(capture.index) == *name_ix { + found_content = Some((capture_range, ContentKind::Name)); + } else if Some(capture.index) == *list_ix { + found_content = Some((capture_range, ContentKind::List)); + } else if Some(capture.index) == *wildcard_ix { + found_content = Some((capture_range, ContentKind::Wildcard)); + } + if let Some((found_content_range, found_kind)) = found_content { + if let Some((_, old_kind)) = content { + let point = found_content_range.to_point(snapshot); + log::warn!( + "bug in {} imports query: unexpected multiple captures of {} and {} ({}:{}:{})", + query_match.language.name(), + old_kind.capture_name(), + found_kind.capture_name(), + snapshot + .file() + .map(|p| p.path().display(PathStyle::Posix)) + .unwrap_or_default(), + point.start.row + 1, + point.start.column + 1 + ); + } + content = Some((found_content_range, found_kind)); + } + } + } + + if let Some(new_import_range) = new_import_range { + log::trace!("starting new import {:?}", new_import_range); + Self::gather_from_import_statement( + &detached_nodes, + &snapshot, + parent_abs_path, + &mut identifier_to_imports, + &mut wildcard_modules, + ); + detached_nodes.clear(); + import_range = Some(new_import_range.clone()); + } + + if let Some((content, content_kind)) = content { + if import_range + .as_ref() + .is_some_and(|import_range| import_range.contains_inclusive(&content)) + { + detached_nodes.push(DetachedNode { + modules, + content: content.clone(), + content_kind, + alias: alias_range.unwrap_or(0..0), + language: query_match.language.clone(), + }); + } else { + log::trace!( + "filtered out match not inside import range: {content_kind:?} at {content:?}" + ); + } + } + + matches.advance(); + } + + Self::gather_from_import_statement( + &detached_nodes, + &snapshot, + parent_abs_path, + &mut identifier_to_imports, + &mut wildcard_modules, + ); + + Imports { + identifier_to_imports, + wildcard_modules, + } + } + + fn gather_from_import_statement( + detached_nodes: &[DetachedNode], + snapshot: &BufferSnapshot, + parent_abs_path: Option<&Path>, + identifier_to_imports: &mut HashMap>, + wildcard_modules: &mut Vec, + ) { + let mut trees = Vec::new(); + + for detached_node in detached_nodes { + if let Some(node) = Self::attach_node(detached_node.into(), &mut trees) { + trees.push(node); + } + 
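+ // Illustrative walkthrough: for `use std::{any::TypeId, collections::HashMap};`, matches + // whose ranges contain previously attached ones absorb them as children here, so a single + // tree rooted at `std` ends up yielding both `std::any::TypeId` and + // `std::collections::HashMap` (see `test_rust_nested`).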
log::trace!( + "Attached node to tree\n{:#?}\nAttach result:\n{:#?}", + detached_node, + trees + .iter() + .map(|tree| tree.debug(snapshot)) + .collect::>() + ); + } + + for tree in &trees { + let mut module = Module::empty(); + Self::gather_from_tree( + tree, + snapshot, + parent_abs_path, + &mut module, + identifier_to_imports, + wildcard_modules, + ); + } + } + + fn attach_node(mut node: ImportTree, trees: &mut Vec) -> Option { + let mut tree_index = 0; + while tree_index < trees.len() { + let tree = &mut trees[tree_index]; + if !node.content.is_empty() && node.content == tree.content { + // multiple matches can apply to the same name/list/wildcard. This keeps the queries + // simpler by combining info from these matches. + if tree.module.is_empty() { + tree.module = node.module; + tree.module_children = node.module_children; + } + if tree.alias.is_empty() { + tree.alias = node.alias; + } + return None; + } else if !node.module.is_empty() && node.module.contains_inclusive(&tree.range()) { + node.module_children.push(trees.remove(tree_index)); + continue; + } else if !node.content.is_empty() && node.content.contains_inclusive(&tree.content) { + node.content_children.push(trees.remove(tree_index)); + continue; + } else if !tree.content.is_empty() && tree.content.contains_inclusive(&node.content) { + if let Some(node) = Self::attach_node(node, &mut tree.content_children) { + tree.content_children.push(node); + } + return None; + } + tree_index += 1; + } + Some(node) + } + + fn gather_from_tree( + tree: &ImportTree, + snapshot: &BufferSnapshot, + parent_abs_path: Option<&Path>, + current_module: &mut Module, + identifier_to_imports: &mut HashMap>, + wildcard_modules: &mut Vec, + ) { + let mut pop_count = 0; + + if tree.module_children.is_empty() { + pop_count += + current_module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path); + } else { + for child in &tree.module_children { + pop_count += Self::extend_namespace_from_tree( + child, + snapshot, + parent_abs_path, + current_module, + ); + } + }; + + if tree.content_children.is_empty() && !tree.content.is_empty() { + match tree.content_kind { + ContentKind::Name | ContentKind::List => { + if tree.alias.is_empty() { + identifier_to_imports + .entry(Identifier { + language_id: tree.language.id(), + name: range_text(snapshot, &tree.content), + }) + .or_default() + .push(Import::Direct { + module: current_module.clone(), + }); + } else { + let alias_name: Arc = range_text(snapshot, &tree.alias); + let external_name = range_text(snapshot, &tree.content); + // Language specific: skip "_" aliases for Rust + if alias_name.as_ref() != "_" { + identifier_to_imports + .entry(Identifier { + language_id: tree.language.id(), + name: alias_name, + }) + .or_default() + .push(Import::Alias { + module: current_module.clone(), + external_identifier: Identifier { + language_id: tree.language.id(), + name: external_name, + }, + }); + } + } + } + ContentKind::Wildcard => wildcard_modules.push(current_module.clone()), + } + } else { + for child in &tree.content_children { + Self::gather_from_tree( + child, + snapshot, + parent_abs_path, + current_module, + identifier_to_imports, + wildcard_modules, + ); + } + } + + if pop_count > 0 { + match current_module { + Module::SourceExact(_) | Module::SourceFuzzy(_) => { + log::warn!( + "bug in imports query: encountered both @namespace and @source match" + ); + } + Module::Namespace(namespace) => { + namespace.0.drain(namespace.0.len() - pop_count..); + } + } + } + } + + fn extend_namespace_from_tree( + 
tree: &ImportTree, + snapshot: &BufferSnapshot, + parent_abs_path: Option<&Path>, + module: &mut Module, + ) -> usize { + let mut pop_count = 0; + if tree.module_children.is_empty() { + pop_count += module.push_range(&tree.module, snapshot, &tree.language, parent_abs_path); + } else { + for child in &tree.module_children { + pop_count += + Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module); + } + } + if tree.content_children.is_empty() { + pop_count += module.push_range( + &ModuleRange::Namespace(tree.content.clone()), + snapshot, + &tree.language, + parent_abs_path, + ); + } else { + for child in &tree.content_children { + pop_count += + Self::extend_namespace_from_tree(child, snapshot, parent_abs_path, module); + } + } + pop_count + } +} + +fn range_text(snapshot: &BufferSnapshot, range: &Range) -> Arc { + snapshot + .text_for_range(range.clone()) + .collect::>() + .into() +} + +#[derive(Debug)] +struct DetachedNode { + modules: Vec, + content: Range, + content_kind: ContentKind, + alias: Range, + language: Arc, +} + +#[derive(Debug, Clone, Copy)] +enum ContentKind { + Name, + Wildcard, + List, +} + +impl ContentKind { + fn capture_name(&self) -> &'static str { + match self { + ContentKind::Name => "name", + ContentKind::Wildcard => "wildcard", + ContentKind::List => "list", + } + } +} + +#[derive(Debug)] +struct ImportTree { + module: ModuleRange, + /// When non-empty, provides namespace / source info which should be used instead of `module`. + module_children: Vec, + content: Range, + /// When non-empty, provides content which should be used instead of `content`. + content_children: Vec, + content_kind: ContentKind, + alias: Range, + language: Arc, +} + +impl ImportTree { + fn range(&self) -> Range { + self.module.start.min(self.content.start)..self.module.end.max(self.content.end) + } + + #[allow(dead_code)] + fn debug<'a>(&'a self, snapshot: &'a BufferSnapshot) -> ImportTreeDebug<'a> { + ImportTreeDebug { + tree: self, + snapshot, + } + } + + fn from_module_range(module: &ModuleRange, language: Arc) -> Self { + ImportTree { + module: module.clone(), + module_children: Vec::new(), + content: 0..0, + content_children: Vec::new(), + content_kind: ContentKind::Name, + alias: 0..0, + language, + } + } +} + +impl From<&DetachedNode> for ImportTree { + fn from(value: &DetachedNode) -> Self { + let module; + let module_children; + match value.modules.len() { + 0 => { + module = ModuleRange::Namespace(0..0); + module_children = Vec::new(); + } + 1 => { + module = value.modules[0].clone(); + module_children = Vec::new(); + } + _ => { + module = ModuleRange::Namespace( + value.modules.first().unwrap().start..value.modules.last().unwrap().end, + ); + module_children = value + .modules + .iter() + .map(|module| ImportTree::from_module_range(module, value.language.clone())) + .collect(); + } + } + + ImportTree { + module, + module_children, + content: value.content.clone(), + content_children: Vec::new(), + content_kind: value.content_kind, + alias: value.alias.clone(), + language: value.language.clone(), + } + } +} + +struct ImportTreeDebug<'a> { + tree: &'a ImportTree, + snapshot: &'a BufferSnapshot, +} + +impl std::fmt::Debug for ImportTreeDebug<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ImportTree") + .field("module_range", &self.tree.module) + .field("module_text", &range_text(self.snapshot, &self.tree.module)) + .field( + "module_children", + &self + .tree + .module_children + .iter() + .map(|child| 
child.debug(&self.snapshot)) + .collect::<Vec<_>>(), + ) + .field("content_range", &self.tree.content) + .field( + "content_text", + &range_text(self.snapshot, &self.tree.content), + ) + .field( + "content_children", + &self + .tree + .content_children + .iter() + .map(|child| child.debug(&self.snapshot)) + .collect::<Vec<_>>(), + ) + .field("content_kind", &self.tree.content_kind) + .field("alias_range", &self.tree.alias) + .field("alias_text", &range_text(self.snapshot, &self.tree.alias)) + .finish() + } +} + +#[cfg(test)] +mod test { + use std::path::PathBuf; + use std::sync::{Arc, LazyLock}; + + use super::*; + use collections::HashSet; + use gpui::{TestAppContext, prelude::*}; + use indoc::indoc; + use language::{ + Buffer, Language, LanguageConfig, tree_sitter_python, tree_sitter_rust, + tree_sitter_typescript, + }; + use regex::Regex; + + #[gpui::test] + fn test_rust_simple(cx: &mut TestAppContext) { + check_imports( + &RUST, + "use std::collections::HashMap;", + &[&["std", "collections", "HashMap"]], + cx, + ); + + check_imports( + &RUST, + "pub use std::collections::HashMap;", + &[&["std", "collections", "HashMap"]], + cx, + ); + + check_imports( + &RUST, + "use std::collections::{HashMap, HashSet};", + &[ + &["std", "collections", "HashMap"], + &["std", "collections", "HashSet"], + ], + cx, + ); + } + + #[gpui::test] + fn test_rust_nested(cx: &mut TestAppContext) { + check_imports( + &RUST, + "use std::{any::TypeId, collections::{HashMap, HashSet}};", + &[ + &["std", "any", "TypeId"], + &["std", "collections", "HashMap"], + &["std", "collections", "HashSet"], + ], + cx, + ); + + check_imports( + &RUST, + "use a::b::c::{d::e::F, g::h::I};", + &[ + &["a", "b", "c", "d", "e", "F"], + &["a", "b", "c", "g", "h", "I"], + ], + cx, + ); + } + + #[gpui::test] + fn test_rust_multiple_imports(cx: &mut TestAppContext) { + check_imports( + &RUST, + indoc! {" + use std::collections::HashMap; + use std::any::{TypeId, Any}; + "}, + &[ + &["std", "collections", "HashMap"], + &["std", "any", "TypeId"], + &["std", "any", "Any"], + ], + cx, + ); + + check_imports( + &RUST, + indoc! {" + use std::collections::HashSet; + + fn main() { + let unqualified = HashSet::new(); + let qualified = std::collections::HashMap::new(); + } + + use std::any::TypeId; + "}, + &[ + &["std", "collections", "HashSet"], + &["std", "any", "TypeId"], + ], + cx, + ); + } + + #[gpui::test] + fn test_rust_wildcard(cx: &mut TestAppContext) { + check_imports(&RUST, "use prelude::*;", &[&["prelude", "WILDCARD"]], cx); + + check_imports( + &RUST, + "use zed::prelude::*;", + &[&["zed", "prelude", "WILDCARD"]], + cx, + ); + + check_imports(&RUST, "use prelude::{*};", &[&["prelude", "WILDCARD"]], cx); + + check_imports( + &RUST, + "use prelude::{File, *};", + &[&["prelude", "File"], &["prelude", "WILDCARD"]], + cx, + ); + + check_imports( + &RUST, + "use zed::{App, prelude::*};", + &[&["zed", "App"], &["zed", "prelude", "WILDCARD"]], + cx, + ); + } + + #[gpui::test] + fn test_rust_alias(cx: &mut TestAppContext) { + check_imports( + &RUST, + "use std::io::Result as IoResult;", + &[&["std", "io", "Result AS IoResult"]], + cx, + ); + } + + #[gpui::test] + fn test_rust_crate_and_super(cx: &mut TestAppContext) { + check_imports(&RUST, "use crate::a::b::c;", &[&["a", "b", "c"]], cx); + check_imports(&RUST, "use super::a::b::c;", &[&["a", "b", "c"]], cx); + // TODO: Consider stripping leading "::". Not done for now because for the text similarity matching use case this + // is fine.
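+ // (Hence the expected `"::a"` first segment in the check below.)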
+ check_imports(&RUST, "use ::a::b::c;", &[&["::a", "b", "c"]], cx); + } + + #[gpui::test] + fn test_typescript_imports(cx: &mut TestAppContext) { + let parent_abs_path = PathBuf::from("/home/user/project"); + + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import "./maths.js";"#, + &[&["SOURCE /home/user/project/maths", "WILDCARD"]], + cx, + ); + + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import "../maths.js";"#, + &[&["SOURCE /home/user/maths", "WILDCARD"]], + cx, + ); + + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import RandomNumberGenerator, { pi as π } from "./maths.js";"#, + &[ + &["SOURCE /home/user/project/maths", "RandomNumberGenerator"], + &["SOURCE /home/user/project/maths", "pi AS π"], + ], + cx, + ); + + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import { pi, phi, absolute } from "./maths.js";"#, + &[ + &["SOURCE /home/user/project/maths", "pi"], + &["SOURCE /home/user/project/maths", "phi"], + &["SOURCE /home/user/project/maths", "absolute"], + ], + cx, + ); + + // index.js is removed by import_path_strip_regex + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import { pi, phi, absolute } from "./maths/index.js";"#, + &[ + &["SOURCE /home/user/project/maths", "pi"], + &["SOURCE /home/user/project/maths", "phi"], + &["SOURCE /home/user/project/maths", "absolute"], + ], + cx, + ); + + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import type { SomeThing } from "./some-module.js";"#, + &[&["SOURCE /home/user/project/some-module", "SomeThing"]], + cx, + ); + + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import { type SomeThing, OtherThing } from "./some-module.js";"#, + &[ + &["SOURCE /home/user/project/some-module", "SomeThing"], + &["SOURCE /home/user/project/some-module", "OtherThing"], + ], + cx, + ); + + // index.js is removed by import_path_strip_regex + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import { type SomeThing, OtherThing } from "./some-module/index.js";"#, + &[ + &["SOURCE /home/user/project/some-module", "SomeThing"], + &["SOURCE /home/user/project/some-module", "OtherThing"], + ], + cx, + ); + + // fuzzy paths + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import { type SomeThing, OtherThing } from "@my-app/some-module.js";"#, + &[ + &["SOURCE FUZZY @my-app/some-module", "SomeThing"], + &["SOURCE FUZZY @my-app/some-module", "OtherThing"], + ], + cx, + ); + } + + #[gpui::test] + fn test_typescript_named_module_imports(cx: &mut TestAppContext) { + let parent_abs_path = PathBuf::from("/home/user/project"); + + // TODO: These should provide the name that the module is bound to. + // For now instead these are treated as unqualified wildcard imports. 
+ check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import * as math from "./maths.js";"#, + // &[&["/home/user/project/maths.js", "WILDCARD AS math"]], + &[&["SOURCE /home/user/project/maths", "WILDCARD"]], + cx, + ); + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &TYPESCRIPT, + r#"import math = require("./maths");"#, + // &[&["/home/user/project/maths", "WILDCARD AS math"]], + &[&["SOURCE /home/user/project/maths", "WILDCARD"]], + cx, + ); + } + + #[gpui::test] + fn test_python_imports(cx: &mut TestAppContext) { + check_imports(&PYTHON, "from math import pi", &[&["math", "pi"]], cx); + + check_imports( + &PYTHON, + "from math import pi, sin, cos", + &[&["math", "pi"], &["math", "sin"], &["math", "cos"]], + cx, + ); + + check_imports(&PYTHON, "from math import *", &[&["math", "WILDCARD"]], cx); + + check_imports( + &PYTHON, + "from math import foo.bar.baz", + &[&["math", "foo", "bar", "baz"]], + cx, + ); + + check_imports( + &PYTHON, + "from math import pi as PI", + &[&["math", "pi AS PI"]], + cx, + ); + + check_imports( + &PYTHON, + "from serializers.json import JsonSerializer", + &[&["serializers", "json", "JsonSerializer"]], + cx, + ); + + check_imports( + &PYTHON, + "from custom.serializers import json, xml, yaml", + &[ + &["custom", "serializers", "json"], + &["custom", "serializers", "xml"], + &["custom", "serializers", "yaml"], + ], + cx, + ); + } + + #[gpui::test] + fn test_python_named_module_imports(cx: &mut TestAppContext) { + // TODO: These should provide the name that the module is bound to. + // For now instead these are treated as unqualified wildcard imports. + // + // check_imports(&PYTHON, "import math", &[&["math", "WILDCARD as math"]], cx); + // check_imports(&PYTHON, "import math as maths", &[&["math", "WILDCARD AS maths"]], cx); + // + // Something like: + // + // (import_statement + // name: [ + // (dotted_name + // (identifier)* @namespace + // (identifier) @name.module .) + // (aliased_import + // name: (dotted_name + // ((identifier) ".")* @namespace + // (identifier) @name.module .) + // alias: (identifier) @alias) + // ]) @import + + check_imports(&PYTHON, "import math", &[&["math", "WILDCARD"]], cx); + + check_imports( + &PYTHON, + "import math as maths", + &[&["math", "WILDCARD"]], + cx, + ); + + check_imports(&PYTHON, "import a.b.c", &[&["a", "b", "c", "WILDCARD"]], cx); + + check_imports( + &PYTHON, + "import a.b.c as d", + &[&["a", "b", "c", "WILDCARD"]], + cx, + ); + } + + #[gpui::test] + fn test_python_package_relative_imports(cx: &mut TestAppContext) { + // TODO: These should provide info about the dir they are relative to, to provide more + // precise resolution. Instead, fuzzy matching is used as usual. + + check_imports(&PYTHON, "from . 
import math", &[&["math"]], cx); + + check_imports(&PYTHON, "from .a import math", &[&["a", "math"]], cx); + + check_imports( + &PYTHON, + "from ..a.b import math", + &[&["a", "b", "math"]], + cx, + ); + + check_imports( + &PYTHON, + "from ..a.b import *", + &[&["a", "b", "WILDCARD"]], + cx, + ); + } + + #[gpui::test] + fn test_c_imports(cx: &mut TestAppContext) { + let parent_abs_path = PathBuf::from("/home/user/project"); + + // TODO: Distinguish that these are not relative to current path + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &C, + r#"#include "#, + &[&["SOURCE FUZZY math.h", "WILDCARD"]], + cx, + ); + + // TODO: These should be treated as relative, but don't start with ./ or ../ + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &C, + r#"#include "math.h""#, + &[&["SOURCE FUZZY math.h", "WILDCARD"]], + cx, + ); + } + + #[gpui::test] + fn test_cpp_imports(cx: &mut TestAppContext) { + let parent_abs_path = PathBuf::from("/home/user/project"); + + // TODO: Distinguish that these are not relative to current path + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &CPP, + r#"#include "#, + &[&["SOURCE FUZZY math.h", "WILDCARD"]], + cx, + ); + + // TODO: These should be treated as relative, but don't start with ./ or ../ + check_imports_with_file_abs_path( + Some(&parent_abs_path), + &CPP, + r#"#include "math.h""#, + &[&["SOURCE FUZZY math.h", "WILDCARD"]], + cx, + ); + } + + #[gpui::test] + fn test_go_imports(cx: &mut TestAppContext) { + check_imports( + &GO, + r#"import . "lib/math""#, + &[&["lib/math", "WILDCARD"]], + cx, + ); + + // not included, these are only for side-effects + check_imports(&GO, r#"import _ "lib/math""#, &[], cx); + } + + #[gpui::test] + fn test_go_named_module_imports(cx: &mut TestAppContext) { + // TODO: These should provide the name that the module is bound to. + // For now instead these are treated as unqualified wildcard imports. 
+ + check_imports( + &GO, + r#"import "lib/math""#, + &[&["lib/math", "WILDCARD"]], + cx, + ); + check_imports( + &GO, + r#"import m "lib/math""#, + &[&["lib/math", "WILDCARD"]], + cx, + ); + } + + #[track_caller] + fn check_imports( + language: &Arc, + source: &str, + expected: &[&[&str]], + cx: &mut TestAppContext, + ) { + check_imports_with_file_abs_path(None, language, source, expected, cx); + } + + #[track_caller] + fn check_imports_with_file_abs_path( + parent_abs_path: Option<&Path>, + language: &Arc, + source: &str, + expected: &[&[&str]], + cx: &mut TestAppContext, + ) { + let buffer = cx.new(|cx| { + let mut buffer = Buffer::local(source, cx); + buffer.set_language(Some(language.clone()), cx); + buffer + }); + cx.run_until_parked(); + + let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot()); + + let imports = Imports::gather(&snapshot, parent_abs_path); + let mut actual_symbols = imports + .identifier_to_imports + .iter() + .flat_map(|(identifier, imports)| { + imports + .iter() + .map(|import| import.to_identifier_parts(identifier.name.as_ref())) + }) + .chain( + imports + .wildcard_modules + .iter() + .map(|module| module.to_identifier_parts("WILDCARD")), + ) + .collect::>(); + let mut expected_symbols = expected + .iter() + .map(|expected| expected.iter().map(|s| s.to_string()).collect::>()) + .collect::>(); + actual_symbols.sort(); + expected_symbols.sort(); + if actual_symbols != expected_symbols { + let top_layer = snapshot.syntax_layers().next().unwrap(); + panic!( + "Expected imports: {:?}\n\ + Actual imports: {:?}\n\ + Tree:\n{}", + expected_symbols, + actual_symbols, + tree_to_string(&top_layer.node()), + ); + } + } + + fn tree_to_string(node: &tree_sitter::Node) -> String { + let mut cursor = node.walk(); + let mut result = String::new(); + let mut depth = 0; + 'outer: loop { + result.push_str(&" ".repeat(depth)); + if let Some(field_name) = cursor.field_name() { + result.push_str(field_name); + result.push_str(": "); + } + if cursor.node().is_named() { + result.push_str(cursor.node().kind()); + } else { + result.push('"'); + result.push_str(cursor.node().kind()); + result.push('"'); + } + result.push('\n'); + + if cursor.goto_first_child() { + depth += 1; + continue; + } + if cursor.goto_next_sibling() { + continue; + } + while cursor.goto_parent() { + depth -= 1; + if cursor.goto_next_sibling() { + continue 'outer; + } + } + break; + } + result + } + + static RUST: LazyLock> = LazyLock::new(|| { + Arc::new( + Language::new( + LanguageConfig { + name: "Rust".into(), + ignored_import_segments: HashSet::from_iter(["crate".into(), "super".into()]), + import_path_strip_regex: Some(Regex::new("/(lib|mod)\\.rs$").unwrap()), + ..Default::default() + }, + Some(tree_sitter_rust::LANGUAGE.into()), + ) + .with_imports_query(include_str!("../../languages/src/rust/imports.scm")) + .unwrap(), + ) + }); + + static TYPESCRIPT: LazyLock> = LazyLock::new(|| { + Arc::new( + Language::new( + LanguageConfig { + name: "TypeScript".into(), + import_path_strip_regex: Some(Regex::new("(?:/index)?\\.[jt]s$").unwrap()), + ..Default::default() + }, + Some(tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into()), + ) + .with_imports_query(include_str!("../../languages/src/typescript/imports.scm")) + .unwrap(), + ) + }); + + static PYTHON: LazyLock> = LazyLock::new(|| { + Arc::new( + Language::new( + LanguageConfig { + name: "Python".into(), + import_path_strip_regex: Some(Regex::new("/__init__\\.py$").unwrap()), + ..Default::default() + }, + Some(tree_sitter_python::LANGUAGE.into()), + ) 
+
+    // TODO: Ideally should use actual language configurations
+    static C: LazyLock<Arc<Language>> = LazyLock::new(|| {
+        Arc::new(
+            Language::new(
+                LanguageConfig {
+                    name: "C".into(),
+                    import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()),
+                    ..Default::default()
+                },
+                Some(tree_sitter_c::LANGUAGE.into()),
+            )
+            .with_imports_query(include_str!("../../languages/src/c/imports.scm"))
+            .unwrap(),
+        )
+    });
+
+    static CPP: LazyLock<Arc<Language>> = LazyLock::new(|| {
+        Arc::new(
+            Language::new(
+                LanguageConfig {
+                    name: "C++".into(),
+                    import_path_strip_regex: Some(Regex::new("^<|>$").unwrap()),
+                    ..Default::default()
+                },
+                Some(tree_sitter_cpp::LANGUAGE.into()),
+            )
+            .with_imports_query(include_str!("../../languages/src/cpp/imports.scm"))
+            .unwrap(),
+        )
+    });
+
+    static GO: LazyLock<Arc<Language>> = LazyLock::new(|| {
+        Arc::new(
+            Language::new(
+                LanguageConfig {
+                    name: "Go".into(),
+                    ..Default::default()
+                },
+                Some(tree_sitter_go::LANGUAGE.into()),
+            )
+            .with_imports_query(include_str!("../../languages/src/go/imports.scm"))
+            .unwrap(),
+        )
+    });
+
+    impl Import {
+        fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
+            match self {
+                Import::Direct { module } => module.to_identifier_parts(identifier),
+                Import::Alias {
+                    module,
+                    external_identifier: external_name,
+                } => {
+                    module.to_identifier_parts(&format!("{} AS {}", external_name.name, identifier))
+                }
+            }
+        }
+    }
+
+    impl Module {
+        fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
+            match self {
+                Self::Namespace(namespace) => namespace.to_identifier_parts(identifier),
+                Self::SourceExact(path) => {
+                    vec![
+                        format!("SOURCE {}", path.display().to_string().replace("\\", "/")),
+                        identifier.to_string(),
+                    ]
+                }
+                Self::SourceFuzzy(path) => {
+                    vec![
+                        format!(
+                            "SOURCE FUZZY {}",
+                            path.display().to_string().replace("\\", "/")
+                        ),
+                        identifier.to_string(),
+                    ]
+                }
+            }
+        }
+    }
+
+    impl Namespace {
+        fn to_identifier_parts(&self, identifier: &str) -> Vec<String> {
+            self.0
+                .iter()
+                .map(|chunk| chunk.to_string())
+                .chain(std::iter::once(identifier.to_string()))
+                .collect::<Vec<_>>()
+        }
+    }
+}
diff --git a/crates/edit_prediction_context/src/syntax_index.rs b/crates/edit_prediction_context/src/syntax_index.rs
index d2763a6cfdf4d65c992fee2ad10d6e15e9387530..e2728ebfc029c7c1b74a35f2e6f5a79003a9a77e 100644
--- a/crates/edit_prediction_context/src/syntax_index.rs
+++ b/crates/edit_prediction_context/src/syntax_index.rs
@@ -5,6 +5,7 @@ use futures::lock::Mutex;
 use futures::{FutureExt as _, StreamExt, future};
 use gpui::{App, AppContext as _, AsyncApp, Context, Entity, Task, WeakEntity};
 use itertools::Itertools;
+
 use language::{Buffer, BufferEvent};
 use postage::stream::Stream as _;
 use project::buffer_store::{BufferStore, BufferStoreEvent};
@@ -17,6 +18,7 @@ use std::sync::Arc;
 use text::BufferId;
 use util::{RangeExt as _, debug_panic, some_or_debug_panic};
 
+use crate::CachedDeclarationPath;
 use crate::declaration::{
     BufferDeclaration, Declaration, DeclarationId, FileDeclaration, Identifier,
 };
@@ -28,6 +30,8 @@ use crate::outline::declarations_in_buffer;
 // `buffer_declarations_containing_range` assumes that the index is always immediately up to date.
 //
 // * Add a per language configuration for skipping indexing.
+//
+// * Handle tsx / ts / js referencing each-other
 
 // Potential future improvements:
 //
@@ -61,6 +65,7 @@ pub struct SyntaxIndex {
     state: Arc<Mutex<SyntaxIndexState>>,
     project: WeakEntity<Project>,
     initial_file_indexing_done_rx: postage::watch::Receiver<bool>,
+    _file_indexing_task: Option<Task<()>>,
 }
 
 pub struct SyntaxIndexState {
@@ -70,7 +75,6 @@ pub struct SyntaxIndexState {
     buffers: HashMap<BufferId, BufferState>,
     dirty_files: HashMap<ProjectEntryId, ProjectPath>,
     dirty_files_tx: mpsc::Sender<()>,
-    _file_indexing_task: Option<Task<()>>,
 }
 
 #[derive(Debug, Default)]
@@ -102,12 +106,12 @@ impl SyntaxIndex {
             buffers: HashMap::default(),
             dirty_files: HashMap::default(),
             dirty_files_tx,
-            _file_indexing_task: None,
         };
-        let this = Self {
+        let mut this = Self {
            project: project.downgrade(),
            state: Arc::new(Mutex::new(initial_state)),
            initial_file_indexing_done_rx,
+            _file_indexing_task: None,
        };
 
         let worktree_store = project.read(cx).worktree_store();
@@ -116,75 +120,77 @@
             .worktrees()
             .map(|w| w.read(cx).snapshot())
             .collect::<Vec<_>>();
-        if !initial_worktree_snapshots.is_empty() {
-            this.state.try_lock().unwrap()._file_indexing_task =
-                Some(cx.spawn(async move |this, cx| {
-                    let snapshots_file_count = initial_worktree_snapshots
-                        .iter()
-                        .map(|worktree| worktree.file_count())
-                        .sum::<usize>();
-                    let chunk_size = snapshots_file_count.div_ceil(file_indexing_parallelism);
-                    let chunk_count = snapshots_file_count.div_ceil(chunk_size);
-                    let file_chunks = initial_worktree_snapshots
-                        .iter()
-                        .flat_map(|worktree| {
-                            let worktree_id = worktree.id();
-                            worktree.files(false, 0).map(move |entry| {
-                                (
-                                    entry.id,
-                                    ProjectPath {
-                                        worktree_id,
-                                        path: entry.path.clone(),
-                                    },
-                                )
-                            })
+        this._file_indexing_task = Some(cx.spawn(async move |this, cx| {
+            let snapshots_file_count = initial_worktree_snapshots
+                .iter()
+                .map(|worktree| worktree.file_count())
+                .sum::<usize>();
+            if snapshots_file_count > 0 {
+                let chunk_size = snapshots_file_count.div_ceil(file_indexing_parallelism);
+                let chunk_count = snapshots_file_count.div_ceil(chunk_size);
+                let file_chunks = initial_worktree_snapshots
+                    .iter()
+                    .flat_map(|worktree| {
+                        let worktree_id = worktree.id();
+                        worktree.files(false, 0).map(move |entry| {
+                            (
+                                entry.id,
+                                ProjectPath {
+                                    worktree_id,
+                                    path: entry.path.clone(),
+                                },
+                            )
                        })
-                            .chunks(chunk_size);
-
-                    let mut tasks = Vec::with_capacity(chunk_count);
-                    for chunk in file_chunks.into_iter() {
-                        tasks.push(Self::update_dirty_files(
-                            &this,
-                            chunk.into_iter().collect(),
-                            cx.clone(),
-                        ));
-                    }
-                    futures::future::join_all(tasks).await;
-
-                    log::info!("Finished initial file indexing");
-                    *initial_file_indexing_done_tx.borrow_mut() = true;
-
-                    let Ok(state) = this.read_with(cx, |this, _cx| this.state.clone()) else {
-                        return;
-                    };
-                    while dirty_files_rx.next().await.is_some() {
-                        let mut state = state.lock().await;
-                        let was_underused = state.dirty_files.capacity() > 255
-                            && state.dirty_files.len() * 8 < state.dirty_files.capacity();
-                        let dirty_files = state.dirty_files.drain().collect::<Vec<_>>();
-                        if was_underused {
-                            state.dirty_files.shrink_to_fit();
-                        }
-                        drop(state);
-                        if dirty_files.is_empty() {
-                            continue;
-                        }
+                    })
+                    .chunks(chunk_size);
+
+                let mut tasks = Vec::with_capacity(chunk_count);
+                for chunk in file_chunks.into_iter() {
+                    tasks.push(Self::update_dirty_files(
+                        &this,
+                        chunk.into_iter().collect(),
+                        cx.clone(),
+                    ));
+                }
+                futures::future::join_all(tasks).await;
+                log::info!("Finished initial file indexing");
+            }
-                        let chunk_size = dirty_files.len().div_ceil(file_indexing_parallelism);
-                        let chunk_count = dirty_files.len().div_ceil(chunk_size);
-                        let mut tasks = Vec::with_capacity(chunk_count);
-                        let chunks = dirty_files.into_iter().chunks(chunk_size);
-                        for chunk in chunks.into_iter() {
-                            tasks.push(Self::update_dirty_files(
-                                &this,
-                                chunk.into_iter().collect(),
-                                cx.clone(),
-                            ));
-                        }
-                        futures::future::join_all(tasks).await;
-                    }
-                }));
-        }
+            *initial_file_indexing_done_tx.borrow_mut() = true;
+
+            let Ok(state) = this.read_with(cx, |this, _cx| Arc::downgrade(&this.state)) else {
+                return;
+            };
+            while dirty_files_rx.next().await.is_some() {
+                let Some(state) = state.upgrade() else {
+                    return;
+                };
+                let mut state = state.lock().await;
+                let was_underused = state.dirty_files.capacity() > 255
+                    && state.dirty_files.len() * 8 < state.dirty_files.capacity();
+                let dirty_files = state.dirty_files.drain().collect::<Vec<_>>();
+                if was_underused {
+                    state.dirty_files.shrink_to_fit();
+                }
+                drop(state);
+                if dirty_files.is_empty() {
+                    continue;
+                }
+
+                let chunk_size = dirty_files.len().div_ceil(file_indexing_parallelism);
+                let chunk_count = dirty_files.len().div_ceil(chunk_size);
+                let mut tasks = Vec::with_capacity(chunk_count);
+                let chunks = dirty_files.into_iter().chunks(chunk_size);
+                for chunk in chunks.into_iter() {
+                    tasks.push(Self::update_dirty_files(
+                        &this,
+                        chunk.into_iter().collect(),
+                        cx.clone(),
+                    ));
+                }
+                futures::future::join_all(tasks).await;
+            }
+        }));
 
         cx.subscribe(&worktree_store, Self::handle_worktree_store_event)
             .detach();
@@ -364,7 +370,9 @@ impl SyntaxIndex {
         cx: &mut Context<Self>,
     ) {
         match event {
-            BufferEvent::Edited => self.update_buffer(buffer, cx),
+            BufferEvent::Edited |
+            // paths are cached and so should be updated
+            BufferEvent::FileHandleChanged => self.update_buffer(buffer, cx),
             _ => {}
         }
     }
@@ -375,8 +383,16 @@ impl SyntaxIndex {
             return;
         }
 
-        let Some(project_entry_id) =
-            project::File::from_dyn(buffer.file()).and_then(|f| f.project_entry_id(cx))
+        let Some((project_entry_id, cached_path)) = project::File::from_dyn(buffer.file())
+            .and_then(|f| {
+                let project_entry_id = f.project_entry_id()?;
+                let cached_path = CachedDeclarationPath::new(
+                    f.worktree.read(cx).abs_path(),
+                    &f.path,
+                    buffer.language(),
+                );
+                Some((project_entry_id, cached_path))
+            })
         else {
             return;
         };
@@ -440,6 +456,7 @@ impl SyntaxIndex {
                 buffer_id,
                 declaration,
                 project_entry_id,
+                cached_path: cached_path.clone(),
             });
 
             new_ids.push(declaration_id);
@@ -507,13 +524,14 @@ impl SyntaxIndex {
         let snapshot_task = worktree.update(cx, |worktree, cx| {
             let load_task = worktree.load_file(&project_path.path, cx);
+            let worktree_abs_path = worktree.abs_path();
             cx.spawn(async move |_this, cx| {
                 let loaded_file = load_task.await?;
                 let language = language.await?;
 
                 let buffer = cx.new(|cx| {
                     let mut buffer = Buffer::local(loaded_file.text, cx);
-                    buffer.set_language(Some(language), cx);
+                    buffer.set_language(Some(language.clone()), cx);
                     buffer
                 })?;
 
@@ -522,14 +540,22 @@ impl SyntaxIndex {
                     parse_status.changed().await?;
                 }
 
-                buffer.read_with(cx, |buffer, _cx| buffer.snapshot())
+                let cached_path = CachedDeclarationPath::new(
+                    worktree_abs_path,
+                    &project_path.path,
+                    Some(&language),
+                );
+
+                let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
+
+                anyhow::Ok((snapshot, cached_path))
             })
         });
 
         let state = Arc::downgrade(&self.state);
         cx.background_spawn(async move {
             // TODO: How to handle errors?
-            let Ok(snapshot) = snapshot_task.await else {
+            let Ok((snapshot, cached_path)) = snapshot_task.await else {
                 return;
             };
             let rope = snapshot.as_rope();
@@ -567,6 +593,7 @@ impl SyntaxIndex {
                 let declaration_id = state.declarations.insert(Declaration::File {
                     project_entry_id: entry_id,
                     declaration,
+                    cached_path: cached_path.clone(),
                 });
 
                 new_ids.push(declaration_id);
@@ -921,6 +948,7 @@ mod tests {
             if let Declaration::File {
                 declaration,
                 project_entry_id: file,
+                ..
             } = declaration
             {
                 assert_eq!(
diff --git a/crates/edit_prediction_context/src/text_similarity.rs b/crates/edit_prediction_context/src/text_similarity.rs
index 99d8fb4dd191bbec1b8c695f274a0024c6cb32ae..308a9570206084fc223c72f2e1c49109ea157714 100644
--- a/crates/edit_prediction_context/src/text_similarity.rs
+++ b/crates/edit_prediction_context/src/text_similarity.rs
@@ -1,9 +1,12 @@
 use hashbrown::HashTable;
 use regex::Regex;
 use std::{
+    borrow::Cow,
     hash::{Hash, Hasher as _},
+    path::Path,
     sync::LazyLock,
 };
+use util::rel_path::RelPath;
 
 use crate::reference::Reference;
 
@@ -45,19 +48,34 @@ impl Occurrences {
         )
     }
 
-    pub fn from_identifiers<'a>(identifiers: impl IntoIterator<Item = &'a str>) -> Self {
+    pub fn from_identifiers(identifiers: impl IntoIterator<Item = impl AsRef<str>>) -> Self {
         let mut this = Self::default();
         // TODO: Score matches that match case higher?
         //
         // TODO: Also include unsplit identifier?
         for identifier in identifiers {
-            for identifier_part in split_identifier(identifier) {
+            for identifier_part in split_identifier(identifier.as_ref()) {
                 this.add_hash(fx_hash(&identifier_part.to_lowercase()));
             }
         }
         this
     }
 
+    pub fn from_worktree_path(worktree_name: Option<Cow<'_, str>>, rel_path: &RelPath) -> Self {
+        if let Some(worktree_name) = worktree_name {
+            Self::from_identifiers(
+                std::iter::once(worktree_name)
+                    .chain(iter_path_without_extension(rel_path.as_std_path())),
+            )
+        } else {
+            Self::from_path(rel_path.as_std_path())
+        }
+    }
+
+    pub fn from_path(path: &Path) -> Self {
+        Self::from_identifiers(iter_path_without_extension(path))
+    }
+
     fn add_hash(&mut self, hash: u64) {
         self.table
             .entry(
@@ -82,6 +100,15 @@ impl Occurrences {
     }
 }
 
+fn iter_path_without_extension(path: &Path) -> impl Iterator<Item = Cow<'_, str>> {
+    let last_component: Option<Cow<str>> = path.file_stem().map(|stem| stem.to_string_lossy());
+    let mut path_components = path.components();
+    path_components.next_back();
+    path_components
+        .map(|component| component.as_os_str().to_string_lossy())
+        .chain(last_component)
+}
+
 pub fn fx_hash<T: Hash + ?Sized>(data: &T) -> u64 {
     let mut hasher = collections::FxHasher::default();
     data.hash(&mut hasher);
@@ -269,4 +296,19 @@ mod test {
         // the smaller set, 10.
         assert_eq!(weighted_overlap_coefficient(&set_a, &set_b), 7.0 / 10.0);
     }
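The test above pins `weighted_overlap_coefficient` to `7.0 / 10.0`: the weight shared between the two term multisets, divided by the total weight of the smaller one. A standalone sketch of that reading of the metric — `weighted_overlap` here is illustrative, not the crate's actual implementation:

```rust
// Sketch only: weighted overlap over term-count maps. For each shared term,
// take the smaller count; normalize by the smaller multiset's total weight.
use std::collections::HashMap;

fn weighted_overlap(a: &HashMap<&str, usize>, b: &HashMap<&str, usize>) -> f32 {
    let shared: usize = a
        .iter()
        .filter_map(|(term, &count_a)| b.get(term).map(|&count_b| count_a.min(count_b)))
        .sum();
    let total_a: usize = a.values().sum();
    let total_b: usize = b.values().sum();
    shared as f32 / total_a.min(total_b) as f32
}

fn main() {
    // Mirrors the shape of the test above: 7 shared occurrences, smaller total 10.
    let a = HashMap::from([("outline", 5), ("index", 3), ("excerpt", 2)]); // total 10
    let b = HashMap::from([("outline", 4), ("index", 3), ("other", 5)]); // total 12
    assert_eq!(weighted_overlap(&a, &b), 7.0 / 10.0);
}
```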
+
+    #[test]
+    fn test_iter_path_without_extension() {
+        let mut iter = iter_path_without_extension(Path::new(""));
+        assert_eq!(iter.next(), None);
+
+        let iter = iter_path_without_extension(Path::new("foo"));
+        assert_eq!(iter.collect::<Vec<_>>(), ["foo"]);
+
+        let iter = iter_path_without_extension(Path::new("foo/bar.txt"));
+        assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar"]);
+
+        let iter = iter_path_without_extension(Path::new("foo/bar/baz.txt"));
+        assert_eq!(iter.collect::<Vec<_>>(), ["foo", "bar", "baz"]);
+    }
 }
diff --git a/crates/editor/src/editor.rs b/crates/editor/src/editor.rs
index 0f3b4a928d02f4292af548fc4c08b5751406a27b..2837f1f564f6ef188595371c0301b7fd7bcf6019 100644
--- a/crates/editor/src/editor.rs
+++ b/crates/editor/src/editor.rs
@@ -5343,7 +5343,7 @@ impl Editor {
             let buffer_worktree = project.worktree_for_id(buffer_file.worktree_id(cx), cx)?;
             let worktree_entry = buffer_worktree
                 .read(cx)
-                .entry_for_id(buffer_file.project_entry_id(cx)?)?;
+                .entry_for_id(buffer_file.project_entry_id()?)?;
             if worktree_entry.is_ignored {
                 return None;
             }
diff --git a/crates/language/src/language.rs b/crates/language/src/language.rs
index 5ba5ffea9b1620c6b66fba7dfbe20cc8fe00ff1b..c16e90bd0f6c02fe49e2845ab24f8d767b32d82b 100644
--- a/crates/language/src/language.rs
+++ b/crates/language/src/language.rs
@@ -777,6 +777,15 @@ pub struct LanguageConfig {
     /// A list of preferred debuggers for this language.
     #[serde(default)]
     pub debuggers: IndexSet<SharedString>,
+    /// A list of import namespace segments that aren't expected to appear in file paths. For
+    /// example, "super" and "crate" in Rust.
+    #[serde(default)]
+    pub ignored_import_segments: HashSet<Arc<str>>,
+    /// Regular expression that matches substrings to omit from import paths, to make the paths more
+    /// similar to how they are specified when imported. For example, "/mod\.rs$" or "/__init__\.py$".
+    #[serde(default, deserialize_with = "deserialize_regex")]
+    #[schemars(schema_with = "regex_json_schema")]
+    pub import_path_strip_regex: Option<Regex>,
 }
 
 #[derive(Clone, Debug, Deserialize, Default, JsonSchema)]
@@ -973,6 +982,8 @@ impl Default for LanguageConfig {
             completion_query_characters: Default::default(),
             linked_edit_characters: Default::default(),
             debuggers: Default::default(),
+            ignored_import_segments: Default::default(),
+            import_path_strip_regex: None,
         }
     }
 }
@@ -1162,6 +1173,7 @@ pub struct Grammar {
     pub(crate) injection_config: Option<InjectionConfig>,
     pub(crate) override_config: Option<OverrideConfig>,
     pub(crate) debug_variables_config: Option<DebugVariablesConfig>,
+    pub(crate) imports_config: Option<ImportsConfig>,
     pub(crate) highlight_map: Mutex<HighlightMap>,
 }
@@ -1314,6 +1326,17 @@ pub struct DebugVariablesConfig {
     pub objects_by_capture_ix: Vec<(u32, DebuggerTextObject)>,
 }
 
+pub struct ImportsConfig {
+    pub query: Query,
+    pub import_ix: u32,
+    pub name_ix: Option<u32>,
+    pub namespace_ix: Option<u32>,
+    pub source_ix: Option<u32>,
+    pub list_ix: Option<u32>,
+    pub wildcard_ix: Option<u32>,
+    pub alias_ix: Option<u32>,
+}
+
 impl Language {
     pub fn new(config: LanguageConfig, ts_language: Option<tree_sitter::Language>) -> Self {
         Self::new_with_id(LanguageId::new(), config, ts_language)
@@ -1346,6 +1369,7 @@ impl Language {
             runnable_config: None,
             error_query: Query::new(&ts_language, "(ERROR) @error").ok(),
             debug_variables_config: None,
+            imports_config: None,
             ts_language,
             highlight_map: Default::default(),
         })
@@ -1427,6 +1451,11 @@ impl Language {
                 .with_debug_variables_query(query.as_ref())
                 .context("Error loading debug variables query")?;
         }
+        if let Some(query) = queries.imports {
+            self = self
+                .with_imports_query(query.as_ref())
+                .context("Error loading imports query")?;
+        }
         Ok(self)
     }
@@ -1595,6 +1624,45 @@ impl Language {
         Ok(self)
     }
 
+    pub fn with_imports_query(mut self, source: &str) -> Result<Self> {
+        let query = Query::new(&self.expect_grammar()?.ts_language, source)?;
+
+        let mut import_ix = 0;
+        let mut name_ix = None;
+        let mut namespace_ix = None;
+        let mut source_ix = None;
+        let mut list_ix = None;
+        let mut wildcard_ix = None;
+        let mut alias_ix = None;
+        if populate_capture_indices(
+            &query,
+            &self.config.name,
+            "imports",
+            &[],
+            &mut [
+                Capture::Required("import", &mut import_ix),
+                Capture::Optional("name", &mut name_ix),
+                Capture::Optional("namespace", &mut namespace_ix),
+                Capture::Optional("source", &mut source_ix),
+                Capture::Optional("list", &mut list_ix),
+                Capture::Optional("wildcard", &mut wildcard_ix),
+                Capture::Optional("alias", &mut alias_ix),
+            ],
+        ) {
+            self.grammar_mut()?.imports_config = Some(ImportsConfig {
+                query,
+                import_ix,
+                name_ix,
+                namespace_ix,
+                source_ix,
+                list_ix,
+                wildcard_ix,
+                alias_ix,
+            });
+        }
+        return Ok(self);
+    }
+
     pub fn with_brackets_query(mut self, source: &str) -> Result<Self> {
         let query = Query::new(&self.expect_grammar()?.ts_language, source)?;
         let mut open_capture_ix = 0;
@@ -2149,6 +2217,10 @@ impl Grammar {
     pub fn debug_variables_config(&self) -> Option<&DebugVariablesConfig> {
         self.debug_variables_config.as_ref()
     }
+
+    pub fn imports_config(&self) -> Option<&ImportsConfig> {
+        self.imports_config.as_ref()
+    }
 }
 
 impl CodeLabel {
diff --git a/crates/language/src/language_registry.rs b/crates/language/src/language_registry.rs
index 1e44660b891f62c37587fcc2d4bf83b040849af6..022eb89e6d2b378b8c4305c81887060d776bb411 100644
--- a/crates/language/src/language_registry.rs
+++ b/crates/language/src/language_registry.rs
@@ -229,6 +229,7 @@ pub const QUERY_FILENAME_PREFIXES: &[(
     ("runnables", |q| &mut q.runnables),
     ("debugger", |q| &mut q.debugger),
("textobjects", |q| &mut q.text_objects), + ("imports", |q| &mut q.imports), ]; /// Tree-sitter language queries for a given language. @@ -245,6 +246,7 @@ pub struct LanguageQueries { pub runnables: Option>, pub text_objects: Option>, pub debugger: Option>, + pub imports: Option>, } #[derive(Clone, Default)] diff --git a/crates/languages/src/c/config.toml b/crates/languages/src/c/config.toml index 74290fd9e2b31db93bb62187ab707110c818fc44..76a27ccc81911bcf25c7da3efef191214eab7b00 100644 --- a/crates/languages/src/c/config.toml +++ b/crates/languages/src/c/config.toml @@ -17,3 +17,4 @@ brackets = [ ] debuggers = ["CodeLLDB", "GDB"] documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } +import_path_strip_regex = "^<|>$" diff --git a/crates/languages/src/c/imports.scm b/crates/languages/src/c/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..c3c2c9e68c4503d323d039f9c042d9501b5e4126 --- /dev/null +++ b/crates/languages/src/c/imports.scm @@ -0,0 +1,7 @@ +(preproc_include + path: [ + ( + (system_lib_string) @source @wildcard + (#strip! @source "[<>]")) + (string_literal (string_content) @source @wildcard) + ]) @import diff --git a/crates/languages/src/cpp/config.toml b/crates/languages/src/cpp/config.toml index 7e24415f9d44c75cfe18065bbe264f0da0f561de..4d3c0a0a38664f4dd584a0ce3f3544662b19bbae 100644 --- a/crates/languages/src/cpp/config.toml +++ b/crates/languages/src/cpp/config.toml @@ -17,3 +17,4 @@ brackets = [ ] debuggers = ["CodeLLDB", "GDB"] documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } +import_path_strip_regex = "^<|>$" diff --git a/crates/languages/src/cpp/imports.scm b/crates/languages/src/cpp/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..a4ef817a80dbcd44336bdd8cd681587662aad435 --- /dev/null +++ b/crates/languages/src/cpp/imports.scm @@ -0,0 +1,5 @@ +(preproc_include + path: [ + ((system_lib_string) @source @wildcard) + (string_literal (string_content) @source @wildcard) + ]) @import diff --git a/crates/languages/src/go/imports.scm b/crates/languages/src/go/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..7f0ff2d46e6a271d4258d23f46cc942830e2c6f9 --- /dev/null +++ b/crates/languages/src/go/imports.scm @@ -0,0 +1,14 @@ +(import_spec + name: [ + (dot) + (package_identifier) + ] + path: (interpreted_string_literal + (interpreted_string_literal_content) @namespace) +) @wildcard @import + +(import_spec + !name + path: (interpreted_string_literal + (interpreted_string_literal_content) @namespace) +) @wildcard @import diff --git a/crates/languages/src/javascript/config.toml b/crates/languages/src/javascript/config.toml index 3bac37aa13ed34c18d1fb8e4f70e0905938e5213..265f362ce4b655371471649c03c5a4a201da320c 100644 --- a/crates/languages/src/javascript/config.toml +++ b/crates/languages/src/javascript/config.toml @@ -23,6 +23,7 @@ tab_size = 2 scope_opt_in_language_servers = ["tailwindcss-language-server", "emmet-language-server"] prettier_parser_name = "babel" debuggers = ["JavaScript"] +import_path_strip_regex = "(?:/index)?\\.[jt]s$" [jsx_tag_auto_close] open_tag_node_name = "jsx_opening_element" diff --git a/crates/languages/src/javascript/imports.scm b/crates/languages/src/javascript/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..e26b97aeef9cb62395e7030f3173208d79187bd6 --- /dev/null +++ b/crates/languages/src/javascript/imports.scm @@ -0,0 +1,14 @@ +(import_statement + import_clause: (import_clause 
+ [ + (identifier) @name + (named_imports + (import_specifier + name: (_) @name + alias: (_)? @alias)) + ]) + source: (string (string_fragment) @source)) @import + +(import_statement + !import_clause + source: (string (string_fragment) @source @wildcard)) @import diff --git a/crates/languages/src/python/config.toml b/crates/languages/src/python/config.toml index 3e8b9b550af33fd9594dd14eda12fb81e220d7b9..c58a54fc1cae78cfb3722e74008fe42c7a883851 100644 --- a/crates/languages/src/python/config.toml +++ b/crates/languages/src/python/config.toml @@ -35,3 +35,4 @@ decrease_indent_patterns = [ { pattern = "^\\s*except\\b.*:\\s*(#.*)?", valid_after = ["try", "except"] }, { pattern = "^\\s*finally\\b.*:\\s*(#.*)?", valid_after = ["try", "except", "else"] }, ] +import_path_strip_regex = "/__init__\\.py$" diff --git a/crates/languages/src/python/imports.scm b/crates/languages/src/python/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..7a1e2b225b9e310098f316c29fe6b1a27634bf12 --- /dev/null +++ b/crates/languages/src/python/imports.scm @@ -0,0 +1,32 @@ +(import_statement + name: [ + (dotted_name + ((identifier) @namespace ".")* + (identifier) @namespace .) + (aliased_import + name: (dotted_name + ((identifier) @namespace ".")* + (identifier) @namespace .)) + ]) @wildcard @import + +(import_from_statement + module_name: [ + (dotted_name + ((identifier) @namespace ".")* + (identifier) @namespace .) + (relative_import + (dotted_name + ((identifier) @namespace ".")* + (identifier) @namespace .)?) + ] + (wildcard_import)? @wildcard + name: [ + (dotted_name + ((identifier) @namespace ".")* + (identifier) @name .) + (aliased_import + name: (dotted_name + ((identifier) @namespace ".")* + (identifier) @name .) + alias: (identifier) @alias) + ]?) @import diff --git a/crates/languages/src/rust/config.toml b/crates/languages/src/rust/config.toml index fe8b4ffdcba4f8b7949b6fe9187d16c8504d6688..826a219e9868a3f76a063efe8c91cec0be14c2da 100644 --- a/crates/languages/src/rust/config.toml +++ b/crates/languages/src/rust/config.toml @@ -17,3 +17,5 @@ brackets = [ collapsed_placeholder = " /* ... */ " debuggers = ["CodeLLDB", "GDB"] documentation_comment = { start = "/*", prefix = "* ", end = "*/", tab_size = 1 } +ignored_import_segments = ["crate", "super"] +import_path_strip_regex = "/(lib|mod)\\.rs$" diff --git a/crates/languages/src/rust/imports.scm b/crates/languages/src/rust/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..3ce6a4f073506dd4d27320a7fd5bb547927f9c1a --- /dev/null +++ b/crates/languages/src/rust/imports.scm @@ -0,0 +1,27 @@ +(use_declaration) @import + +(scoped_use_list + path: (_) @namespace + list: (_) @list) + +(scoped_identifier + path: (_) @namespace + name: (identifier) @name) + +(use_list (identifier) @name) + +(use_declaration (identifier) @name) + +(use_as_clause + path: (scoped_identifier + path: (_) @namespace + name: (_) @name) + alias: (_) @alias) + +(use_as_clause + path: (identifier) @name + alias: (_) @alias) + +(use_wildcard + (_)? @namespace + "*" @wildcard) diff --git a/crates/languages/src/tsx/imports.scm b/crates/languages/src/tsx/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..e26b97aeef9cb62395e7030f3173208d79187bd6 --- /dev/null +++ b/crates/languages/src/tsx/imports.scm @@ -0,0 +1,14 @@ +(import_statement + import_clause: (import_clause + [ + (identifier) @name + (named_imports + (import_specifier + name: (_) @name + alias: (_)? 
@alias)) + ]) + source: (string (string_fragment) @source)) @import + +(import_statement + !import_clause + source: (string (string_fragment) @source @wildcard)) @import diff --git a/crates/languages/src/typescript/config.toml b/crates/languages/src/typescript/config.toml index fe56e496ec717895e72f37dda9146fbb30b50e88..67656e6a538da6c8860e9ab1b08fd6e6ee9cabbd 100644 --- a/crates/languages/src/typescript/config.toml +++ b/crates/languages/src/typescript/config.toml @@ -22,6 +22,7 @@ prettier_parser_name = "typescript" tab_size = 2 debuggers = ["JavaScript"] scope_opt_in_language_servers = ["tailwindcss-language-server"] +import_path_strip_regex = "(?:/index)?\\.[jt]s$" [overrides.string] completion_query_characters = ["-", "."] diff --git a/crates/languages/src/typescript/imports.scm b/crates/languages/src/typescript/imports.scm new file mode 100644 index 0000000000000000000000000000000000000000..68ca25b2c15b7e312edbc3eeb9b2f0e493ca2d6f --- /dev/null +++ b/crates/languages/src/typescript/imports.scm @@ -0,0 +1,20 @@ +(import_statement + import_clause: (import_clause + [ + (identifier) @name + (named_imports + (import_specifier + name: (_) @name + alias: (_)? @alias)) + (namespace_import) @wildcard + ]) + source: (string (string_fragment) @source)) @import + +(import_statement + !source + import_clause: (import_require_clause + source: (string (string_fragment) @source))) @wildcard @import + +(import_statement + !import_clause + source: (string (string_fragment) @source)) @wildcard @import diff --git a/crates/outline_panel/src/outline_panel.rs b/crates/outline_panel/src/outline_panel.rs index a5307ee0d33b5ee2b24b98f9377b3b5d7ae57fd4..6dcc572e6c022112f886b9c192a65064040cf1af 100644 --- a/crates/outline_panel/src/outline_panel.rs +++ b/crates/outline_panel/src/outline_panel.rs @@ -2668,7 +2668,7 @@ impl OutlinePanel { |mut buffer_excerpts, (excerpt_id, buffer_snapshot, excerpt_range)| { let buffer_id = buffer_snapshot.remote_id(); let file = File::from_dyn(buffer_snapshot.file()); - let entry_id = file.and_then(|file| file.project_entry_id(cx)); + let entry_id = file.and_then(|file| file.project_entry_id()); let worktree = file.map(|file| file.worktree.read(cx).snapshot()); let is_new = new_entries.contains(&excerpt_id) || !outline_panel.excerpts.contains_key(&buffer_id); diff --git a/crates/project/src/project.rs b/crates/project/src/project.rs index bc7e8ad89fd01468f5e6009dda45632ab738a07c..2b6c9bfe6c45bfff8b17f05ba115923b41efc6ec 100644 --- a/crates/project/src/project.rs +++ b/crates/project/src/project.rs @@ -2571,8 +2571,8 @@ impl Project { let task = self.open_buffer(path, cx); cx.spawn(async move |_project, cx| { let buffer = task.await?; - let project_entry_id = buffer.read_with(cx, |buffer, cx| { - File::from_dyn(buffer.file()).and_then(|file| file.project_entry_id(cx)) + let project_entry_id = buffer.read_with(cx, |buffer, _cx| { + File::from_dyn(buffer.file()).and_then(|file| file.project_entry_id()) })?; Ok((project_entry_id, buffer)) @@ -5515,8 +5515,8 @@ impl ProjectItem for Buffer { Some(project.update(cx, |project, cx| project.open_buffer(path.clone(), cx))) } - fn entry_id(&self, cx: &App) -> Option { - File::from_dyn(self.file()).and_then(|file| file.project_entry_id(cx)) + fn entry_id(&self, _cx: &App) -> Option { + File::from_dyn(self.file()).and_then(|file| file.project_entry_id()) } fn project_path(&self, cx: &App) -> Option { diff --git a/crates/util/src/paths.rs b/crates/util/src/paths.rs index 
8fc62ae1178ad74448590d6dabea5ea421b2b292..d31828eb568978fdcddbf1030badb5911c730004 100644 --- a/crates/util/src/paths.rs +++ b/crates/util/src/paths.rs @@ -4,6 +4,7 @@ use itertools::Itertools; use regex::Regex; use serde::{Deserialize, Serialize}; use std::cmp::Ordering; +use std::error::Error; use std::fmt::{Display, Formatter}; use std::mem; use std::path::StripPrefixError; @@ -184,6 +185,31 @@ impl> PathExt for T { } } +pub fn path_ends_with(base: &Path, suffix: &Path) -> bool { + strip_path_suffix(base, suffix).is_some() +} + +pub fn strip_path_suffix<'a>(base: &'a Path, suffix: &Path) -> Option<&'a Path> { + if let Some(remainder) = base + .as_os_str() + .as_encoded_bytes() + .strip_suffix(suffix.as_os_str().as_encoded_bytes()) + { + if remainder + .last() + .is_none_or(|last_byte| std::path::is_separator(*last_byte as char)) + { + let os_str = unsafe { + OsStr::from_encoded_bytes_unchecked( + &remainder[0..remainder.len().saturating_sub(1)], + ) + }; + return Some(Path::new(os_str)); + } + } + None +} + /// In memory, this is identical to `Path`. On non-Windows conversions to this type are no-ops. On /// windows, these conversions sanitize UNC paths by removing the `\\\\?\\` prefix. #[derive(Eq, PartialEq, Hash, Ord, PartialOrd)] @@ -401,6 +427,82 @@ pub fn is_absolute(path_like: &str, path_style: PathStyle) -> bool { .is_some_and(|path| path.starts_with('/') || path.starts_with('\\'))) } +#[derive(Debug, PartialEq)] +#[non_exhaustive] +pub struct NormalizeError; + +impl Error for NormalizeError {} + +impl std::fmt::Display for NormalizeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("parent reference `..` points outside of base directory") + } +} + +/// Copied from stdlib where it's unstable. +/// +/// Normalize a path, including `..` without traversing the filesystem. +/// +/// Returns an error if normalization would leave leading `..` components. +/// +///
+/// +/// This function always resolves `..` to the "lexical" parent. +/// That is "a/b/../c" will always resolve to `a/c` which can change the meaning of the path. +/// In particular, `a/c` and `a/b/../c` are distinct on many systems because `b` may be a symbolic link, so its parent isn't `a`. +/// +///
+/// +/// [`path::absolute`](absolute) is an alternative that preserves `..`. +/// Or [`Path::canonicalize`] can be used to resolve any `..` by querying the filesystem. +pub fn normalize_lexically(path: &Path) -> Result { + use std::path::Component; + + let mut lexical = PathBuf::new(); + let mut iter = path.components().peekable(); + + // Find the root, if any, and add it to the lexical path. + // Here we treat the Windows path "C:\" as a single "root" even though + // `components` splits it into two: (Prefix, RootDir). + let root = match iter.peek() { + Some(Component::ParentDir) => return Err(NormalizeError), + Some(p @ Component::RootDir) | Some(p @ Component::CurDir) => { + lexical.push(p); + iter.next(); + lexical.as_os_str().len() + } + Some(Component::Prefix(prefix)) => { + lexical.push(prefix.as_os_str()); + iter.next(); + if let Some(p @ Component::RootDir) = iter.peek() { + lexical.push(p); + iter.next(); + } + lexical.as_os_str().len() + } + None => return Ok(PathBuf::new()), + Some(Component::Normal(_)) => 0, + }; + + for component in iter { + match component { + Component::RootDir => unreachable!(), + Component::Prefix(_) => return Err(NormalizeError), + Component::CurDir => continue, + Component::ParentDir => { + // It's an error if ParentDir causes us to go above the "root". + if lexical.as_os_str().len() == root { + return Err(NormalizeError); + } else { + lexical.pop(); + } + } + Component::Normal(path) => lexical.push(path), + } + } + Ok(lexical) +} + /// A delimiter to use in `path_query:row_number:column_number` strings parsing. pub const FILE_ROW_COLUMN_DELIMITER: char = ':'; @@ -1798,4 +1900,35 @@ mod tests { let path = Path::new("/a/b/c/long.app.tar.gz"); assert_eq!(path.multiple_extensions(), Some("app.tar.gz".to_string())); } + + #[test] + fn test_strip_path_suffix() { + let base = Path::new("/a/b/c/file_name"); + let suffix = Path::new("file_name"); + assert_eq!(strip_path_suffix(base, suffix), Some(Path::new("/a/b/c"))); + + let base = Path::new("/a/b/c/file_name.tsx"); + let suffix = Path::new("file_name.tsx"); + assert_eq!(strip_path_suffix(base, suffix), Some(Path::new("/a/b/c"))); + + let base = Path::new("/a/b/c/file_name.stories.tsx"); + let suffix = Path::new("c/file_name.stories.tsx"); + assert_eq!(strip_path_suffix(base, suffix), Some(Path::new("/a/b"))); + + let base = Path::new("/a/b/c/long.app.tar.gz"); + let suffix = Path::new("b/c/long.app.tar.gz"); + assert_eq!(strip_path_suffix(base, suffix), Some(Path::new("/a"))); + + let base = Path::new("/a/b/c/long.app.tar.gz"); + let suffix = Path::new("/a/b/c/long.app.tar.gz"); + assert_eq!(strip_path_suffix(base, suffix), Some(Path::new(""))); + + let base = Path::new("/a/b/c/long.app.tar.gz"); + let suffix = Path::new("/a/b/c/no_match.app.tar.gz"); + assert_eq!(strip_path_suffix(base, suffix), None); + + let base = Path::new("/a/b/c/long.app.tar.gz"); + let suffix = Path::new("app.tar.gz"); + assert_eq!(strip_path_suffix(base, suffix), None); + } } diff --git a/crates/worktree/src/worktree.rs b/crates/worktree/src/worktree.rs index d1f8901f8833191ead7d29d94db7e28e9567fa8e..eb0b5f861d2181f06bcd7732851cf7d397404786 100644 --- a/crates/worktree/src/worktree.rs +++ b/crates/worktree/src/worktree.rs @@ -3154,7 +3154,7 @@ impl File { self.worktree.read(cx).id() } - pub fn project_entry_id(&self, _: &App) -> Option { + pub fn project_entry_id(&self) -> Option { match self.disk_state { DiskState::Deleted => None, _ => self.entry_id, diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index 
d32b2589135e2d77f9285706254d1e3c05f7f3e2..30ef2d79da05b87c730ccf0c87c4061225d1c723 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -7,8 +7,8 @@ use cloud_llm_client::{ }; use cloud_zeta2_prompt::DEFAULT_MAX_PROMPT_BYTES; use edit_prediction_context::{ - DeclarationId, EditPredictionContext, EditPredictionExcerptOptions, SyntaxIndex, - SyntaxIndexState, + DeclarationId, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions, + EditPredictionExcerptOptions, EditPredictionScoreOptions, SyntaxIndex, SyntaxIndexState, }; use futures::AsyncReadExt as _; use futures::channel::mpsc; @@ -43,14 +43,20 @@ const BUFFER_CHANGE_GROUPING_INTERVAL: Duration = Duration::from_secs(1); /// Maximum number of events to track. const MAX_EVENT_COUNT: usize = 16; -pub const DEFAULT_EXCERPT_OPTIONS: EditPredictionExcerptOptions = EditPredictionExcerptOptions { - max_bytes: 512, - min_bytes: 128, - target_before_cursor_over_total_bytes: 0.5, +pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPredictionContextOptions { + use_imports: true, + excerpt: EditPredictionExcerptOptions { + max_bytes: 512, + min_bytes: 128, + target_before_cursor_over_total_bytes: 0.5, + }, + score: EditPredictionScoreOptions { + omit_excerpt_overlaps: true, + }, }; pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { - excerpt: DEFAULT_EXCERPT_OPTIONS, + context: DEFAULT_CONTEXT_OPTIONS, max_prompt_bytes: DEFAULT_MAX_PROMPT_BYTES, max_diagnostic_bytes: 2048, prompt_format: PromptFormat::DEFAULT, @@ -75,7 +81,7 @@ pub struct Zeta { #[derive(Debug, Clone, PartialEq)] pub struct ZetaOptions { - pub excerpt: EditPredictionExcerptOptions, + pub context: EditPredictionContextOptions, pub max_prompt_bytes: usize, pub max_diagnostic_bytes: usize, pub prompt_format: predict_edits_v3::PromptFormat, @@ -501,6 +507,11 @@ impl Zeta { let diagnostics = snapshot.diagnostic_sets().clone(); + let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { + let mut path = f.worktree.read(cx).absolutize(&f.path); + if path.pop() { Some(path) } else { None } + }); + let request_task = cx.background_spawn({ let snapshot = snapshot.clone(); let buffer = buffer.clone(); @@ -519,7 +530,8 @@ impl Zeta { let Some(context) = EditPredictionContext::gather_context( cursor_point, &snapshot, - &options.excerpt, + parent_abs_path.as_deref(), + &options.context, index_state.as_deref(), ) else { return Ok(None); @@ -785,6 +797,11 @@ impl Zeta { .map(|worktree| worktree.read(cx).snapshot()) .collect::>(); + let parent_abs_path = project::File::from_dyn(buffer.read(cx).file()).and_then(|f| { + let mut path = f.worktree.read(cx).absolutize(&f.path); + if path.pop() { Some(path) } else { None } + }); + cx.background_spawn(async move { let index_state = if let Some(index_state) = index_state { Some(index_state.lock_owned().await) @@ -798,7 +815,8 @@ impl Zeta { EditPredictionContext::gather_context( cursor_point, &snapshot, - &options.excerpt, + parent_abs_path.as_deref(), + &options.context, index_state.as_deref(), ) .context("Failed to select excerpt") @@ -893,9 +911,9 @@ fn make_cloud_request( text_is_truncated, signature_range: snippet.declaration.signature_range_in_item_text(), parent_index, - score_components: snippet.score_components, - signature_score: snippet.scores.signature, - declaration_score: snippet.scores.declaration, + signature_score: snippet.score(DeclarationStyle::Signature), + declaration_score: snippet.score(DeclarationStyle::Declaration), + score_components: 
snippet.components, }); } diff --git a/crates/zeta2_tools/src/zeta2_tools.rs b/crates/zeta2_tools/src/zeta2_tools.rs index a299726c64b104e59cab9ba4609316c49d715876..40315265df4c9a4aec3dfee37185d94249841eda 100644 --- a/crates/zeta2_tools/src/zeta2_tools.rs +++ b/crates/zeta2_tools/src/zeta2_tools.rs @@ -16,7 +16,7 @@ use ui::{ContextMenu, ContextMenuEntry, DropdownMenu, prelude::*}; use ui_input::SingleLineInput; use util::{ResultExt, paths::PathStyle, rel_path::RelPath}; use workspace::{Item, SplitDirection, Workspace}; -use zeta2::{Zeta, ZetaOptions}; +use zeta2::{DEFAULT_CONTEXT_OPTIONS, Zeta, ZetaOptions}; use edit_prediction_context::{DeclarationStyle, EditPredictionExcerptOptions}; @@ -146,16 +146,19 @@ impl Zeta2Inspector { cx: &mut Context, ) { self.max_excerpt_bytes_input.update(cx, |input, cx| { - input.set_text(options.excerpt.max_bytes.to_string(), window, cx); + input.set_text(options.context.excerpt.max_bytes.to_string(), window, cx); }); self.min_excerpt_bytes_input.update(cx, |input, cx| { - input.set_text(options.excerpt.min_bytes.to_string(), window, cx); + input.set_text(options.context.excerpt.min_bytes.to_string(), window, cx); }); self.cursor_context_ratio_input.update(cx, |input, cx| { input.set_text( format!( "{:.2}", - options.excerpt.target_before_cursor_over_total_bytes + options + .context + .excerpt + .target_before_cursor_over_total_bytes ), window, cx, @@ -236,7 +239,8 @@ impl Zeta2Inspector { .unwrap_or_default() } - let excerpt_options = EditPredictionExcerptOptions { + let mut context_options = DEFAULT_CONTEXT_OPTIONS.clone(); + context_options.excerpt = EditPredictionExcerptOptions { max_bytes: number_input_value(&this.max_excerpt_bytes_input, cx), min_bytes: number_input_value(&this.min_excerpt_bytes_input, cx), target_before_cursor_over_total_bytes: number_input_value( @@ -248,7 +252,7 @@ impl Zeta2Inspector { let zeta_options = this.zeta.read(cx).options(); this.set_options( ZetaOptions { - excerpt: excerpt_options, + context: context_options, max_prompt_bytes: number_input_value(&this.max_prompt_bytes_input, cx), max_diagnostic_bytes: zeta_options.max_diagnostic_bytes, prompt_format: zeta_options.prompt_format, diff --git a/crates/zeta_cli/Cargo.toml b/crates/zeta_cli/Cargo.toml index 660de610c14ae3926b787e136d3aa9779156c279..d81a5ae6d34fbe7cba25898fc4885baa84f1dfb2 100644 --- a/crates/zeta_cli/Cargo.toml +++ b/crates/zeta_cli/Cargo.toml @@ -18,6 +18,7 @@ clap.workspace = true client.workspace = true cloud_llm_client.workspace= true cloud_zeta2_prompt.workspace= true +collections.workspace = true debug_adapter_extension.workspace = true edit_prediction_context.workspace = true extension.workspace = true @@ -32,6 +33,7 @@ language_models.workspace = true languages = { workspace = true, features = ["load-grammars"] } log.workspace = true node_runtime.workspace = true +ordered-float.workspace = true paths.workspace = true project.workspace = true prompt_store.workspace = true @@ -49,4 +51,3 @@ workspace-hack.workspace = true zeta.workspace = true zeta2.workspace = true zlog.workspace = true -ordered-float.workspace = true diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index feaf3999dcac8c74783d4a74e408ce0812bc89cd..236c5eb4572cf451a3efd435b9d0ad20d4380b72 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -1,33 +1,40 @@ mod headless; -use anyhow::{Result, anyhow}; +use anyhow::{Context as _, Result, anyhow}; use clap::{Args, Parser, Subcommand}; -use cloud_llm_client::predict_edits_v3; +use 
cloud_llm_client::predict_edits_v3::{self, DeclarationScoreComponents};
 use edit_prediction_context::{
-    Declaration, EditPredictionContext, EditPredictionExcerptOptions, Identifier, ReferenceRegion,
-    SyntaxIndex, references_in_range,
+    Declaration, DeclarationStyle, EditPredictionContext, EditPredictionContextOptions,
+    EditPredictionExcerptOptions, EditPredictionScoreOptions, Identifier, Imports, Reference,
+    ReferenceRegion, SyntaxIndex, SyntaxIndexState, references_in_range,
 };
 use futures::channel::mpsc;
 use futures::{FutureExt as _, StreamExt as _};
 use gpui::{AppContext, Application, AsyncApp};
 use gpui::{Entity, Task};
-use language::{Bias, LanguageServerId};
+use language::{Bias, BufferSnapshot, LanguageServerId, Point};
 use language::{Buffer, OffsetRangeExt};
-use language::{LanguageId, Point};
+use language::{LanguageId, ParseStatus};
 use language_model::LlmApiToken;
 use ordered_float::OrderedFloat;
-use project::{Project, ProjectPath, Worktree};
+use project::{Project, ProjectEntryId, ProjectPath, Worktree};
 use release_channel::AppVersion;
 use reqwest_client::ReqwestClient;
+use serde::{Deserialize, Deserializer, Serialize, Serializer};
 use serde_json::json;
 use std::cmp::Reverse;
 use std::collections::{HashMap, HashSet};
+use std::fmt::{self, Display};
+use std::fs::File;
+use std::hash::Hash;
+use std::hash::Hasher;
 use std::io::Write as _;
 use std::ops::Range;
 use std::path::{Path, PathBuf};
 use std::process::exit;
 use std::str::FromStr;
-use std::sync::Arc;
+use std::sync::atomic::AtomicUsize;
+use std::sync::{Arc, atomic};
 use std::time::Duration;
 use util::paths::PathStyle;
 use util::rel_path::RelPath;
@@ -59,10 +66,16 @@ enum Commands {
         context_args: Option<ContextArgs>,
     },
     RetrievalStats {
+        #[clap(flatten)]
+        zeta2_args: Zeta2Args,
         #[arg(long)]
         worktree: PathBuf,
-        #[arg(long, default_value_t = 42)]
-        file_indexing_parallelism: usize,
+        #[arg(long)]
+        extension: Option<String>,
+        #[arg(long)]
+        limit: Option<usize>,
+        #[arg(long)]
+        skip: Option<usize>,
     },
 }
@@ -72,7 +85,7 @@ struct ContextArgs {
     #[arg(long)]
     worktree: PathBuf,
     #[arg(long)]
-    cursor: CursorPosition,
+    cursor: SourceLocation,
     #[arg(long)]
     use_language_server: bool,
     #[arg(long)]
@@ -97,6 +110,8 @@ struct Zeta2Args {
     output_format: OutputFormat,
     #[arg(long, default_value_t = 42)]
     file_indexing_parallelism: usize,
+    #[arg(long, default_value_t = false)]
+    disable_imports_gathering: bool,
 }
 
 #[derive(clap::ValueEnum, Default, Debug, Clone)]
@@ -151,20 +166,51 @@ impl FromStr for FileOrStdin {
     }
 }
 
-#[derive(Debug, Clone)]
-struct CursorPosition {
+#[derive(Debug, Clone, Hash, Eq, PartialEq)]
+struct SourceLocation {
     path: Arc<RelPath>,
     point: Point,
 }
 
-impl FromStr for CursorPosition {
+impl Serialize for SourceLocation {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        serializer.serialize_str(&self.to_string())
+    }
+}
+
+impl<'de> Deserialize<'de> for SourceLocation {
+    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        let s = String::deserialize(deserializer)?;
+        s.parse().map_err(serde::de::Error::custom)
+    }
+}
+
+impl Display for SourceLocation {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "{}:{}:{}",
+            self.path.display(PathStyle::Posix),
+            self.point.row + 1,
+            self.point.column + 1
+        )
+    }
+}
+
+impl FromStr for SourceLocation {
     type Err = anyhow::Error;
 
     fn from_str(s: &str) -> Result<Self> {
         let parts: Vec<&str> = s.split(':').collect();
         if parts.len() != 3 {
             return Err(anyhow!(
-                "Invalid cursor format. 
Expected 'file.rs:line:column', got '{}'", + "Invalid source location. Expected 'file.rs:line:column', got '{}'", s )); } @@ -180,7 +226,7 @@ impl FromStr for CursorPosition { // Convert from 1-based to 0-based indexing let point = Point::new(line.saturating_sub(1), column.saturating_sub(1)); - Ok(CursorPosition { path, point }) + Ok(SourceLocation { path, point }) } } @@ -225,16 +271,17 @@ async fn get_context( let mut ready_languages = HashSet::default(); let (_lsp_open_handle, buffer) = if use_language_server { let (lsp_open_handle, _, buffer) = open_buffer_with_language_server( - &project, - &worktree, - &cursor.path, + project.clone(), + worktree.clone(), + cursor.path.clone(), &mut ready_languages, cx, ) .await?; (Some(lsp_open_handle), buffer) } else { - let buffer = open_buffer(&project, &worktree, &cursor.path, cx).await?; + let buffer = + open_buffer(project.clone(), worktree.clone(), cursor.path.clone(), cx).await?; (None, buffer) }; @@ -281,18 +328,7 @@ async fn get_context( zeta2::Zeta::new(app_state.client.clone(), app_state.user_store.clone(), cx) }); let indexing_done_task = zeta.update(cx, |zeta, cx| { - zeta.set_options(zeta2::ZetaOptions { - excerpt: EditPredictionExcerptOptions { - max_bytes: zeta2_args.max_excerpt_bytes, - min_bytes: zeta2_args.min_excerpt_bytes, - target_before_cursor_over_total_bytes: zeta2_args - .target_before_cursor_over_total_bytes, - }, - max_diagnostic_bytes: zeta2_args.max_diagnostic_bytes, - max_prompt_bytes: zeta2_args.max_prompt_bytes, - prompt_format: zeta2_args.prompt_format.into(), - file_indexing_parallelism: zeta2_args.file_indexing_parallelism, - }); + zeta.set_options(zeta2_args.to_options(true)); zeta.register_buffer(&buffer, &project, cx); zeta.wait_for_initial_indexing(&project, cx) }); @@ -340,12 +376,39 @@ async fn get_context( } } +impl Zeta2Args { + fn to_options(&self, omit_excerpt_overlaps: bool) -> zeta2::ZetaOptions { + zeta2::ZetaOptions { + context: EditPredictionContextOptions { + use_imports: !self.disable_imports_gathering, + excerpt: EditPredictionExcerptOptions { + max_bytes: self.max_excerpt_bytes, + min_bytes: self.min_excerpt_bytes, + target_before_cursor_over_total_bytes: self + .target_before_cursor_over_total_bytes, + }, + score: EditPredictionScoreOptions { + omit_excerpt_overlaps, + }, + }, + max_diagnostic_bytes: self.max_diagnostic_bytes, + max_prompt_bytes: self.max_prompt_bytes, + prompt_format: self.prompt_format.clone().into(), + file_indexing_parallelism: self.file_indexing_parallelism, + } + } +} + pub async fn retrieval_stats( worktree: PathBuf, - file_indexing_parallelism: usize, app_state: Arc, + only_extension: Option, + file_limit: Option, + skip_files: Option, + options: zeta2::ZetaOptions, cx: &mut AsyncApp, ) -> Result { + let options = Arc::new(options); let worktree_path = worktree.canonicalize()?; let project = cx.update(|cx| { @@ -365,7 +428,6 @@ pub async fn retrieval_stats( project.create_worktree(&worktree_path, true, cx) })? .await?; - let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?; // wait for worktree scan so that wait_for_initial_file_indexing waits for the whole worktree. worktree @@ -374,21 +436,492 @@ pub async fn retrieval_stats( })? .await; - let index = cx.new(|cx| SyntaxIndex::new(&project, file_indexing_parallelism, cx))?; + let index = cx.new(|cx| SyntaxIndex::new(&project, options.file_indexing_parallelism, cx))?; index .read_with(cx, |index, cx| index.wait_for_initial_file_indexing(cx))? 
.await?; - let files = index + let indexed_files = index .read_with(cx, |index, cx| index.indexed_file_paths(cx))? - .await + .await; + let mut filtered_files = indexed_files .into_iter() .filter(|project_path| { - project_path - .path - .extension() - .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension)) + let file_extension = project_path.path.extension(); + if let Some(only_extension) = only_extension.as_ref() { + file_extension.is_some_and(|extension| extension == only_extension) + } else { + file_extension + .is_some_and(|extension| !["md", "json", "sh", "diff"].contains(&extension)) + } }) .collect::>(); + filtered_files.sort_by(|a, b| a.path.cmp(&b.path)); + + let index_state = index.read_with(cx, |index, _cx| index.state().clone())?; + cx.update(|_| { + drop(index); + })?; + let index_state = Arc::new( + Arc::into_inner(index_state) + .context("Index state had more than 1 reference")? + .into_inner(), + ); + + struct FileSnapshot { + project_entry_id: ProjectEntryId, + snapshot: BufferSnapshot, + hash: u64, + parent_abs_path: Arc, + } + + let files: Vec = futures::future::try_join_all({ + filtered_files + .iter() + .map(|file| { + let buffer_task = + open_buffer(project.clone(), worktree.clone(), file.path.clone(), cx); + cx.spawn(async move |cx| { + let buffer = buffer_task.await?; + let (project_entry_id, parent_abs_path, snapshot) = + buffer.read_with(cx, |buffer, cx| { + let file = project::File::from_dyn(buffer.file()).unwrap(); + let project_entry_id = file.project_entry_id().unwrap(); + let mut parent_abs_path = file.worktree.read(cx).absolutize(&file.path); + if !parent_abs_path.pop() { + panic!("Invalid worktree path"); + } + + (project_entry_id, parent_abs_path, buffer.snapshot()) + })?; + + anyhow::Ok( + cx.background_spawn(async move { + let mut hasher = collections::FxHasher::default(); + snapshot.text().hash(&mut hasher); + FileSnapshot { + project_entry_id, + snapshot, + hash: hasher.finish(), + parent_abs_path: parent_abs_path.into(), + } + }) + .await, + ) + }) + }) + .collect::>() + }) + .await?; + + let mut file_snapshots = HashMap::default(); + let mut hasher = collections::FxHasher::default(); + for FileSnapshot { + project_entry_id, + snapshot, + hash, + .. + } in &files + { + file_snapshots.insert(*project_entry_id, snapshot.clone()); + hash.hash(&mut hasher); + } + let files_hash = hasher.finish(); + let file_snapshots = Arc::new(file_snapshots); + + let lsp_definitions_path = std::env::current_dir()?.join(format!( + "target/zeta2-lsp-definitions-{:x}.json", + files_hash + )); + + let lsp_definitions: Arc<_> = if std::fs::exists(&lsp_definitions_path)? { + log::info!( + "Using cached LSP definitions from {}", + lsp_definitions_path.display() + ); + serde_json::from_reader(File::open(&lsp_definitions_path)?)? 
+ } else { + log::warn!( + "No LSP definitions found populating {}", + lsp_definitions_path.display() + ); + let lsp_definitions = + gather_lsp_definitions(&filtered_files, &worktree, &project, cx).await?; + serde_json::to_writer_pretty(File::create(&lsp_definitions_path)?, &lsp_definitions)?; + lsp_definitions + } + .into(); + + let files_len = files.len().min(file_limit.unwrap_or(usize::MAX)); + let done_count = Arc::new(AtomicUsize::new(0)); + + let (output_tx, mut output_rx) = mpsc::unbounded::(); + let mut output = std::fs::File::create("target/zeta-retrieval-stats.txt")?; + + let tasks = files + .into_iter() + .skip(skip_files.unwrap_or(0)) + .take(file_limit.unwrap_or(usize::MAX)) + .map(|project_file| { + let index_state = index_state.clone(); + let lsp_definitions = lsp_definitions.clone(); + let options = options.clone(); + let output_tx = output_tx.clone(); + let done_count = done_count.clone(); + let file_snapshots = file_snapshots.clone(); + cx.background_spawn(async move { + let snapshot = project_file.snapshot; + + let full_range = 0..snapshot.len(); + let references = references_in_range( + full_range, + &snapshot.text(), + ReferenceRegion::Nearby, + &snapshot, + ); + + println!("references: {}", references.len(),); + + let imports = if options.context.use_imports { + Imports::gather(&snapshot, Some(&project_file.parent_abs_path)) + } else { + Imports::default() + }; + + let path = snapshot.file().unwrap().path(); + + for reference in references { + let query_point = snapshot.offset_to_point(reference.range.start); + let source_location = SourceLocation { + path: path.clone(), + point: query_point, + }; + let lsp_definitions = lsp_definitions + .definitions + .get(&source_location) + .cloned() + .unwrap_or_else(|| { + log::warn!( + "No definitions found for source location: {:?}", + source_location + ); + Vec::new() + }); + + let retrieve_result = retrieve_definitions( + &reference, + &imports, + query_point, + &snapshot, + &index_state, + &file_snapshots, + &options, + ) + .await?; + + // TODO: LSP returns things like locals, this filters out some of those, but potentially + // hides some retrieval issues. 
+                    if retrieve_result.definitions.is_empty() {
+                        continue;
+                    }
+
+                    let mut best_match = None;
+                    let mut has_external_definition = false;
+                    let mut in_excerpt = false;
+                    for (index, retrieved_definition) in
+                        retrieve_result.definitions.iter().enumerate()
+                    {
+                        for lsp_definition in &lsp_definitions {
+                            let SourceRange {
+                                path,
+                                point_range,
+                                offset_range,
+                            } = lsp_definition;
+                            let lsp_point_range =
+                                SerializablePoint::into_language_point_range(point_range.clone());
+                            has_external_definition = has_external_definition
+                                || path.is_absolute()
+                                || path
+                                    .components()
+                                    .any(|component| component.as_os_str() == "node_modules");
+                            let is_match = path.as_path()
+                                == retrieved_definition.path.as_std_path()
+                                && retrieved_definition
+                                    .range
+                                    .contains_inclusive(&lsp_point_range);
+                            if is_match && best_match.is_none() {
+                                best_match = Some(index);
+                            }
+                            in_excerpt = in_excerpt
+                                || retrieve_result.excerpt_range.as_ref().is_some_and(
+                                    |excerpt_range| excerpt_range.contains_inclusive(&offset_range),
+                                );
+                        }
+                    }
+
+                    let outcome = if let Some(best_match) = best_match {
+                        RetrievalOutcome::Match { best_match }
+                    } else if has_external_definition {
+                        RetrievalOutcome::NoMatchDueToExternalLspDefinitions
+                    } else if in_excerpt {
+                        RetrievalOutcome::ProbablyLocal
+                    } else {
+                        RetrievalOutcome::NoMatch
+                    };
+
+                    let result = RetrievalStatsResult {
+                        outcome,
+                        path: path.clone(),
+                        identifier: reference.identifier,
+                        point: query_point,
+                        lsp_definitions,
+                        retrieved_definitions: retrieve_result.definitions,
+                    };
+
+                    output_tx.unbounded_send(result).ok();
+                }
+
+                println!(
+                    "{:02}/{:02} done",
+                    done_count.fetch_add(1, atomic::Ordering::Relaxed) + 1,
+                    files_len,
+                );
+
+                anyhow::Ok(())
+            })
+        })
+        .collect::<Vec<_>>();
+
+    drop(output_tx);
+
+    let results_task = cx.background_spawn(async move {
+        let mut results = Vec::new();
+        while let Some(result) = output_rx.next().await {
+            output
+                .write_all(format!("{:#?}\n", result).as_bytes())
+                .log_err();
+            results.push(result);
+        }
+        results
+    });
+
+    futures::future::try_join_all(tasks).await?;
+    println!("Tasks completed");
+    let results = results_task.await;
+    println!("Results received");
+
+    let mut references_count = 0;
+
+    let mut included_count = 0;
+    let mut both_absent_count = 0;
+
+    let mut retrieved_count = 0;
+    let mut top_match_count = 0;
+    let mut non_top_match_count = 0;
+    let mut ranking_involved_top_match_count = 0;
+
+    let mut no_match_count = 0;
+    let mut no_match_none_retrieved = 0;
+    let mut no_match_wrong_retrieval = 0;
+
+    let mut expected_no_match_count = 0;
+    let mut in_excerpt_count = 0;
+    let mut external_definition_count = 0;
+
+    for result in results {
+        references_count += 1;
+        match &result.outcome {
+            RetrievalOutcome::Match { best_match } => {
+                included_count += 1;
+                retrieved_count += 1;
+                let multiple = result.retrieved_definitions.len() > 1;
+                if *best_match == 0 {
+                    top_match_count += 1;
+                    if multiple {
+                        ranking_involved_top_match_count += 1;
+                    }
+                } else {
+                    non_top_match_count += 1;
+                }
+            }
+            RetrievalOutcome::NoMatch => {
+                if result.lsp_definitions.is_empty() {
+                    included_count += 1;
+                    both_absent_count += 1;
+                } else {
+                    no_match_count += 1;
+                    if result.retrieved_definitions.is_empty() {
+                        no_match_none_retrieved += 1;
+                    } else {
+                        no_match_wrong_retrieval += 1;
+                    }
+                }
+            }
+            RetrievalOutcome::NoMatchDueToExternalLspDefinitions => {
+                expected_no_match_count += 1;
+                external_definition_count += 1;
+            }
+            RetrievalOutcome::ProbablyLocal => {
+                included_count += 1;
+                in_excerpt_count += 1;
+            }
+        }
+    }
+
+    fn count_and_percentage(part: usize, total: usize) -> String {
+        if total == 0 {
+            return format!("{} (0.00%)", part);
+        }
+        format!("{} ({:.2}%)", part, (part as f64 / total as f64) * 100.0)
+    }
+
+    println!();
+    println!("╮ references: {}", references_count);
+    println!(
+        "├─╮ included: {}",
+        count_and_percentage(included_count, references_count)
+    );
+    println!(
+        "│ ├─╮ retrieved: {}",
+        count_and_percentage(retrieved_count, references_count)
+    );
+    println!(
+        "│ │ ├─╮ top match: {}",
+        count_and_percentage(top_match_count, retrieved_count)
+    );
+    println!(
+        "│ │ │ ╰─╴ involving ranking: {}",
+        count_and_percentage(ranking_involved_top_match_count, top_match_count)
+    );
+    println!(
+        "│ │ ╰─╴ non-top match: {}",
+        count_and_percentage(non_top_match_count, retrieved_count)
+    );
+    println!(
+        "│ ├─╴ both absent: {}",
+        count_and_percentage(both_absent_count, included_count)
+    );
+    println!(
+        "│ ╰─╴ in excerpt: {}",
+        count_and_percentage(in_excerpt_count, included_count)
+    );
+    println!(
+        "├─╮ no match: {}",
+        count_and_percentage(no_match_count, references_count)
+    );
+    println!(
+        "│ ├─╴ none retrieved: {}",
+        count_and_percentage(no_match_none_retrieved, no_match_count)
+    );
+    println!(
+        "│ ╰─╴ wrong retrieval: {}",
+        count_and_percentage(no_match_wrong_retrieval, no_match_count)
+    );
+    println!(
+        "╰─╮ expected no match: {}",
+        count_and_percentage(expected_no_match_count, references_count)
+    );
+    println!(
+        "  ╰─╴ external definition: {}",
+        count_and_percentage(external_definition_count, expected_no_match_count)
+    );
+
+    println!();
+    println!("LSP definition cache at {}", lsp_definitions_path.display());
+
+    Ok("".to_string())
+}
+
+struct RetrieveResult {
+    definitions: Vec<RetrievedDefinition>,
+    excerpt_range: Option<Range<usize>>,
+}
+
+async fn retrieve_definitions(
+    reference: &Reference,
+    imports: &Imports,
+    query_point: Point,
+    snapshot: &BufferSnapshot,
+    index: &Arc<SyntaxIndexState>,
+    file_snapshots: &Arc<HashMap<ProjectEntryId, BufferSnapshot>>,
+    options: &Arc<zeta2::ZetaOptions>,
+) -> Result<RetrieveResult> {
+    let mut single_reference_map = HashMap::default();
+    single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
+    let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
+        query_point,
+        snapshot,
+        imports,
+        &options.context,
+        Some(&index),
+        |_, _, _| single_reference_map,
+    );
+
+    let Some(edit_prediction_context) = edit_prediction_context else {
+        return Ok(RetrieveResult {
+            definitions: Vec::new(),
+            excerpt_range: None,
+        });
+    };
+
+    let mut retrieved_definitions = Vec::new();
+    for scored_declaration in edit_prediction_context.declarations {
+        match &scored_declaration.declaration {
+            Declaration::File {
+                project_entry_id,
+                declaration,
+                ..
+            } => {
+                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
+                    log::error!("bug: file project entry not found");
+                    continue;
+                };
+                let path = snapshot.file().unwrap().path().clone();
+                retrieved_definitions.push(RetrievedDefinition {
+                    path,
+                    range: snapshot.offset_to_point(declaration.item_range.start)
+                        ..snapshot.offset_to_point(declaration.item_range.end),
+                    score: scored_declaration.score(DeclarationStyle::Declaration),
+                    retrieval_score: scored_declaration.retrieval_score(),
+                    components: scored_declaration.components,
+                });
+            }
+            Declaration::Buffer {
+                project_entry_id,
+                rope,
+                declaration,
+                ..
+            } => {
+                let Some(snapshot) = file_snapshots.get(&project_entry_id) else {
+                    // This case happens when dependency buffers have been opened by
+                    // go-to-definition, resulting in single-file worktrees.
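+                    // Such buffers have no entry in `file_snapshots`, so they cannot
+                    // be resolved to a worktree-relative path here and are skipped.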
+                    continue;
+                };
+                let path = snapshot.file().unwrap().path().clone();
+                retrieved_definitions.push(RetrievedDefinition {
+                    path,
+                    range: rope.offset_to_point(declaration.item_range.start)
+                        ..rope.offset_to_point(declaration.item_range.end),
+                    score: scored_declaration.score(DeclarationStyle::Declaration),
+                    retrieval_score: scored_declaration.retrieval_score(),
+                    components: scored_declaration.components,
+                });
+            }
+        }
+    }
+    retrieved_definitions.sort_by_key(|definition| Reverse(OrderedFloat(definition.score)));
+
+    Ok(RetrieveResult {
+        definitions: retrieved_definitions,
+        excerpt_range: Some(edit_prediction_context.excerpt.range),
+    })
+}
+
+async fn gather_lsp_definitions(
+    files: &[ProjectPath],
+    worktree: &Entity<Worktree>,
+    project: &Entity<Project>,
+    cx: &mut AsyncApp,
+) -> Result<LspResults> {
     let worktree_id = worktree.read_with(cx, |worktree, _cx| worktree.id())?;
 
     let lsp_store = project.read_with(cx, |project, _cx| project.lsp_store())?;
 
     cx.subscribe(&lsp_store, {
@@ -410,24 +943,22 @@ pub async fn retrieval_stats(
     })?
     .detach();
 
+    let mut definitions = HashMap::default();
+    let mut error_count = 0;
     let mut lsp_open_handles = Vec::new();
-    let mut output = std::fs::File::create("retrieval-stats.txt")?;
-    let mut results = Vec::new();
     let mut ready_languages = HashSet::default();
     for (file_index, project_path) in files.iter().enumerate() {
-        let processing_file_message = format!(
+        println!(
            "Processing file {} of {}: {}",
            file_index + 1,
            files.len(),
            project_path.path.display(PathStyle::Posix)
        );
-        println!("{}", processing_file_message);
-        write!(output, "{processing_file_message}\n\n").ok();
 
        let Some((lsp_open_handle, language_server_id, buffer)) =
            open_buffer_with_language_server(
-                &project,
-                &worktree,
-                &project_path.path,
+                project.clone(),
+                worktree.clone(),
+                project_path.path.clone(),
                &mut ready_languages,
                cx,
            )
@@ -463,273 +994,182 @@ pub async fn retrieval_stats(
            .await;
        }
 
-        let index = index.read_with(cx, |index, _cx| index.state().clone())?;
-        let index = index.lock().await;
        for reference in references {
-            let query_point = snapshot.offset_to_point(reference.range.start);
-            let mut single_reference_map = HashMap::default();
-            single_reference_map.insert(reference.identifier.clone(), vec![reference.clone()]);
-            let edit_prediction_context = EditPredictionContext::gather_context_with_references_fn(
-                query_point,
-                &snapshot,
-                &zeta2::DEFAULT_EXCERPT_OPTIONS,
-                Some(&index),
-                |_, _, _| single_reference_map,
-            );
-
-            let Some(edit_prediction_context) = edit_prediction_context else {
-                let result = RetrievalStatsResult {
-                    identifier: reference.identifier,
-                    point: query_point,
-                    outcome: RetrievalStatsOutcome::NoExcerpt,
-                };
-                write!(output, "{:?}\n\n", result)?;
-                results.push(result);
-                continue;
-            };
-
-            let mut retrieved_definitions = Vec::new();
-            for scored_declaration in edit_prediction_context.declarations {
-                match &scored_declaration.declaration {
-                    Declaration::File {
-                        project_entry_id,
-                        declaration,
-                    } => {
-                        let Some(path) = worktree.read_with(cx, |worktree, _cx| {
-                            worktree
-                                .entry_for_id(*project_entry_id)
-                                .map(|entry| entry.path.clone())
-                        })?
-                        else {
-                            log::error!("bug: file project entry not found");
-                            continue;
-                        };
-                        let project_path = ProjectPath {
-                            worktree_id,
-                            path: path.clone(),
-                        };
-                        let buffer = project
-                            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
- .await?; - let rope = buffer.read_with(cx, |buffer, _cx| buffer.as_rope().clone())?; - retrieved_definitions.push(( - path, - rope.offset_to_point(declaration.item_range.start) - ..rope.offset_to_point(declaration.item_range.end), - scored_declaration.scores.declaration, - scored_declaration.scores.retrieval, - )); - } - Declaration::Buffer { - project_entry_id, - rope, - declaration, - .. - } => { - let Some(path) = worktree.read_with(cx, |worktree, _cx| { - worktree - .entry_for_id(*project_entry_id) - .map(|entry| entry.path.clone()) - })? - else { - // This case happens when dependency buffers have been opened by - // go-to-definition, resulting in single-file worktrees. - continue; - }; - retrieved_definitions.push(( - path, - rope.offset_to_point(declaration.item_range.start) - ..rope.offset_to_point(declaration.item_range.end), - scored_declaration.scores.declaration, - scored_declaration.scores.retrieval, - )); - } - } - } - retrieved_definitions - .sort_by_key(|(_, _, _, retrieval_score)| Reverse(OrderedFloat(*retrieval_score))); - - // TODO: Consider still checking language server in this case, or having a mode for - // this. For now assuming that the purpose of this is to refine the ranking rather than - // refining whether the definition is present at all. - if retrieved_definitions.is_empty() { - continue; - } - // TODO: Rename declaration to definition in edit_prediction_context? let lsp_result = project .update(cx, |project, cx| { project.definitions(&buffer, reference.range.start, cx) })? .await; + match lsp_result { Ok(lsp_definitions) => { - let lsp_definitions = lsp_definitions - .unwrap_or_default() - .into_iter() - .filter_map(|definition| { - definition - .target - .buffer - .read_with(cx, |buffer, _cx| { - let path = buffer.file()?.path(); - // filter out definitions from single-file worktrees - if path.is_empty() { - None - } else { - Some(( - path.clone(), - definition.target.range.to_point(&buffer), - )) - } - }) - .ok()? 
-                    })
-                    .collect::<Vec<_>>();
+                        let mut targets = Vec::new();
+                        for target in lsp_definitions.unwrap_or_default() {
+                            let buffer = target.target.buffer;
+                            let anchor_range = target.target.range;
+                            buffer.read_with(cx, |buffer, cx| {
+                                let Some(file) = project::File::from_dyn(buffer.file()) else {
+                                    return;
+                                };
+                                let file_worktree = file.worktree.read(cx);
+                                let file_worktree_id = file_worktree.id();
+                                // Relative paths for worktree files, absolute for all others
+                                let path = if worktree_id != file_worktree_id {
+                                    file.worktree.read(cx).absolutize(&file.path)
+                                } else {
+                                    file.path.as_std_path().to_path_buf()
+                                };
+                                let offset_range = anchor_range.to_offset(&buffer);
+                                let point_range = SerializablePoint::from_language_point_range(
+                                    offset_range.to_point(&buffer),
+                                );
+                                targets.push(SourceRange {
+                                    path,
+                                    offset_range,
+                                    point_range,
+                                });
+                            })?;
+                        }
 
-                let result = RetrievalStatsResult {
-                    identifier: reference.identifier,
-                    point: query_point,
-                    outcome: RetrievalStatsOutcome::Success {
-                        matches: lsp_definitions
-                            .iter()
-                            .map(|(path, range)| {
-                                retrieved_definitions.iter().position(
-                                    |(retrieved_path, retrieved_range, _, _)| {
-                                        path == retrieved_path
-                                            && retrieved_range.contains_inclusive(&range)
-                                    },
-                                )
-                            })
-                            .collect(),
-                        lsp_definitions,
-                        retrieved_definitions,
+                        definitions.insert(
+                            SourceLocation {
+                                path: project_path.path.clone(),
+                                point: snapshot.offset_to_point(reference.range.start),
                             },
-                };
-                write!(output, "{:?}\n\n", result)?;
-                results.push(result);
+                            targets,
+                        );
                     }
                     Err(err) => {
-                        let result = RetrievalStatsResult {
-                            identifier: reference.identifier,
-                            point: query_point,
-                            outcome: RetrievalStatsOutcome::LanguageServerError {
-                                message: err.to_string(),
-                            },
-                        };
-                        write!(output, "{:?}\n\n", result)?;
-                        results.push(result);
+                        log::error!("Language server error: {err}");
+                        error_count += 1;
                     }
                 }
             }
         }
 
-    let mut no_excerpt_count = 0;
-    let mut error_count = 0;
-    let mut definitions_count = 0;
-    let mut top_match_count = 0;
-    let mut non_top_match_count = 0;
-    let mut ranking_involved_count = 0;
-    let mut ranking_involved_top_match_count = 0;
-    let mut ranking_involved_non_top_match_count = 0;
-    for result in &results {
-        match &result.outcome {
-            RetrievalStatsOutcome::NoExcerpt => no_excerpt_count += 1,
-            RetrievalStatsOutcome::LanguageServerError { .. } => error_count += 1,
-            RetrievalStatsOutcome::Success {
-                matches,
-                retrieved_definitions,
-                ..
-            } => {
-                definitions_count += 1;
-                let top_matches = matches.contains(&Some(0));
-                if top_matches {
-                    top_match_count += 1;
-                }
-                let non_top_matches =
-                    !top_matches && matches.iter().any(|index| *index != Some(0));
-                if non_top_matches {
-                    non_top_match_count += 1;
-                }
-                if retrieved_definitions.len() > 1 {
-                    ranking_involved_count += 1;
-                    if top_matches {
-                        ranking_involved_top_match_count += 1;
-                    }
-                    if non_top_matches {
-                        ranking_involved_non_top_match_count += 1;
-                    }
-                }
-            }
-        }
+    if error_count > 0 {
+        log::error!("Encountered {} language server errors", error_count);
+    }
+
+    Ok(LspResults { definitions })
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(transparent)]
+struct LspResults {
+    definitions: HashMap<SourceLocation, Vec<SourceRange>>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+struct SourceRange {
+    path: PathBuf,
+    point_range: Range<SerializablePoint>,
+    offset_range: Range<usize>,
+}
+
+/// Serializes to 1-based row and column indices.
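+/// `language::Point` rows and columns are 0-based, so the `From` impls below
+/// add or remove the one-based offset when converting in either direction.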
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct SerializablePoint {
+    pub row: u32,
+    pub column: u32,
+}
+
+impl SerializablePoint {
+    pub fn into_language_point_range(range: Range<SerializablePoint>) -> Range<Point> {
+        range.start.into()..range.end.into()
+    }
 
-    println!("\nStats:\n");
-    println!("No Excerpt: {}", no_excerpt_count);
-    println!("Language Server Error: {}", error_count);
-    println!("Definitions: {}", definitions_count);
-    println!("Top Match: {}", top_match_count);
-    println!("Non-Top Match: {}", non_top_match_count);
-    println!("Ranking Involved: {}", ranking_involved_count);
-    println!(
-        "Ranking Involved Top Match: {}",
-        ranking_involved_top_match_count
-    );
-    println!(
-        "Ranking Involved Non-Top Match: {}",
-        ranking_involved_non_top_match_count
-    );
+    pub fn from_language_point_range(range: Range<Point>) -> Range<SerializablePoint> {
+        range.start.into()..range.end.into()
+    }
+}
 
-    Ok("".to_string())
+impl From<Point> for SerializablePoint {
+    fn from(point: Point) -> Self {
+        SerializablePoint {
+            row: point.row + 1,
+            column: point.column + 1,
+        }
+    }
+}
+
+impl From<SerializablePoint> for Point {
+    fn from(serializable: SerializablePoint) -> Self {
+        Point {
+            row: serializable.row.saturating_sub(1),
+            column: serializable.column.saturating_sub(1),
+        }
+    }
 }
 
 #[derive(Debug)]
 struct RetrievalStatsResult {
+    outcome: RetrievalOutcome,
+    #[allow(dead_code)]
+    path: Arc<RelPath>,
     #[allow(dead_code)]
     identifier: Identifier,
     #[allow(dead_code)]
     point: Point,
-    outcome: RetrievalStatsOutcome,
+    #[allow(dead_code)]
+    lsp_definitions: Vec<SourceRange>,
+    retrieved_definitions: Vec<RetrievedDefinition>,
 }
 
 #[derive(Debug)]
-enum RetrievalStatsOutcome {
-    NoExcerpt,
-    LanguageServerError {
-        #[allow(dead_code)]
-        message: String,
-    },
-    Success {
-        matches: Vec<Option<usize>>,
-        #[allow(dead_code)]
-        lsp_definitions: Vec<(Arc<RelPath>, Range<Point>)>,
-        retrieved_definitions: Vec<(Arc<RelPath>, Range<Point>, f32, f32)>,
+enum RetrievalOutcome {
+    Match {
+        /// Lowest index within `retrieved_definitions` that matches an LSP definition.
+        best_match: usize,
     },
+    ProbablyLocal,
+    NoMatch,
+    NoMatchDueToExternalLspDefinitions,
 }
 
-pub async fn open_buffer(
-    project: &Entity<Project>,
-    worktree: &Entity<Worktree>,
-    path: &RelPath,
-    cx: &mut AsyncApp,
-) -> Result<Entity<Buffer>> {
-    let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
-        worktree_id: worktree.id(),
-        path: path.into(),
-    })?;
+#[derive(Debug)]
+struct RetrievedDefinition {
+    path: Arc<RelPath>,
+    range: Range<Point>,
+    score: f32,
+    #[allow(dead_code)]
+    retrieval_score: f32,
+    #[allow(dead_code)]
+    components: DeclarationScoreComponents,
+}
 
-    project
-        .update(cx, |project, cx| project.open_buffer(project_path, cx))?
-        .await
+pub fn open_buffer(
+    project: Entity<Project>,
+    worktree: Entity<Worktree>,
+    path: Arc<RelPath>,
+    cx: &AsyncApp,
+) -> Task<Result<Entity<Buffer>>> {
+    cx.spawn(async move |cx| {
+        let project_path = worktree.read_with(cx, |worktree, _cx| ProjectPath {
+            worktree_id: worktree.id(),
+            path,
+        })?;
+
+        let buffer = project
+            .update(cx, |project, cx| project.open_buffer(project_path, cx))?
+            .await?;
+
+        let mut parse_status = buffer.read_with(cx, |buffer, _cx| buffer.parse_status())?;
+        while *parse_status.borrow() != ParseStatus::Idle {
+            parse_status.changed().await?;
+        }
+
+        Ok(buffer)
+    })
 }
 
 pub async fn open_buffer_with_language_server(
-    project: &Entity<Project>,
-    worktree: &Entity<Worktree>,
-    path: &RelPath,
+    project: Entity<Project>,
+    worktree: Entity<Worktree>,
+    path: Arc<RelPath>,
     ready_languages: &mut HashSet<LanguageId>,
     cx: &mut AsyncApp,
 ) -> Result<(Entity<Entity<Buffer>>, LanguageServerId, Entity<Buffer>)> {
-    let buffer = open_buffer(project, worktree, path, cx).await?;
+    let buffer = open_buffer(project.clone(), worktree, path.clone(), cx).await?;
 
     let (lsp_open_handle, path_style) = project.update(cx, |project, cx| {
         (
@@ -940,9 +1380,23 @@ fn main() {
             .await
         }
         Commands::RetrievalStats {
+            zeta2_args,
             worktree,
-            file_indexing_parallelism,
-        } => retrieval_stats(worktree, file_indexing_parallelism, app_state, cx).await,
+            extension,
+            limit,
+            skip,
+        } => {
+            retrieval_stats(
+                worktree,
+                app_state,
+                extension,
+                limit,
+                skip,
+                zeta2_args.to_options(false),
+                cx,
+            )
+            .await
+        }
     };
     match result {
         Ok(output) => {