zeta2: Boost declarations included by others (#39975)

Agus Zubiaga and Michael Sloan created

Release Notes:

- N/A

Co-authored-by: Michael Sloan <michael@zed.dev>

Change summary

crates/cloud_llm_client/src/predict_edits_v3.rs           |   2 
crates/edit_prediction_context/src/declaration_scoring.rs | 365 ++++----
2 files changed, 196 insertions(+), 171 deletions(-)

Detailed changes

crates/cloud_llm_client/src/predict_edits_v3.rs 🔗

@@ -142,6 +142,8 @@ pub struct DeclarationScoreComponents {
     pub normalized_import_similarity: f32,
     pub wildcard_import_similarity: f32,
     pub normalized_wildcard_import_similarity: f32,
+    pub included_by_others: usize,
+    pub includes_others: usize,
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -2,10 +2,12 @@ use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
 use collections::HashMap;
 use language::BufferSnapshot;
 use ordered_float::OrderedFloat;
+use project::ProjectEntryId;
 use serde::Serialize;
 use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
 use strum::EnumIter;
 use text::{Point, ToPoint};
+use util::RangeExt as _;
 
 use crate::{
     CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
@@ -73,7 +75,7 @@ impl ScoredDeclaration {
     }
 
     pub fn retrieval_score(&self) -> f32 {
-        if self.components.is_same_file {
+        let mut score = if self.components.is_same_file {
             10.0 / self.components.same_file_declaration_count as f32
         } else if self.components.path_import_match_count > 0 {
             3.0
@@ -85,7 +87,10 @@ impl ScoredDeclaration {
             0.5 * self.components.normalized_wildcard_import_similarity
         } else {
             1.0 / self.components.declaration_count as f32
-        }
+        };
+        score *= 1. + self.components.included_by_others as f32 / 2.;
+        score *= 1. + self.components.includes_others as f32 / 4.;
+        score
     }
 
     pub fn size(&self, style: DeclarationStyle) -> usize {
@@ -133,194 +138,210 @@ pub fn scored_declarations(
         }
     }
 
-    let mut declarations = identifier_to_references
-        .into_iter()
-        .flat_map(|(identifier, references)| {
-            let mut import_occurrences = Vec::new();
-            let mut import_paths = Vec::new();
-            let mut found_external_identifier: Option<&Identifier> = None;
-
-            if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
-                // only use alias when it's the only import, could be generalized if some language
-                // has overlapping aliases
-                //
-                // TODO: when an aliased declaration is included in the prompt, should include the
-                // aliasing in the prompt.
-                //
-                // TODO: For SourceFuzzy consider having componentwise comparison that pays
-                // attention to ordering.
-                if let [
-                    Import::Alias {
-                        module,
-                        external_identifier,
-                    },
-                ] = imports.as_slice()
-                {
-                    match module {
-                        Module::Namespace(namespace) => {
-                            import_occurrences.push(namespace.occurrences())
-                        }
-                        Module::SourceExact(path) => import_paths.push(path),
-                        Module::SourceFuzzy(path) => {
-                            import_occurrences.push(Occurrences::from_path(&path))
-                        }
+    let mut scored_declarations = Vec::new();
+    let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
+        HashMap::default();
+    for (identifier, references) in identifier_to_references {
+        let mut import_occurrences = Vec::new();
+        let mut import_paths = Vec::new();
+        let mut found_external_identifier: Option<&Identifier> = None;
+
+        if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
+            // only use alias when it's the only import, could be generalized if some language
+            // has overlapping aliases
+            //
+            // TODO: when an aliased declaration is included in the prompt, should include the
+            // aliasing in the prompt.
+            //
+            // TODO: For SourceFuzzy consider having componentwise comparison that pays
+            // attention to ordering.
+            if let [
+                Import::Alias {
+                    module,
+                    external_identifier,
+                },
+            ] = imports.as_slice()
+            {
+                match module {
+                    Module::Namespace(namespace) => {
+                        import_occurrences.push(namespace.occurrences())
                     }
-                    found_external_identifier = Some(&external_identifier);
-                } else {
-                    for import in imports {
-                        match import {
-                            Import::Direct { module } => match module {
-                                Module::Namespace(namespace) => {
-                                    import_occurrences.push(namespace.occurrences())
-                                }
-                                Module::SourceExact(path) => import_paths.push(path),
-                                Module::SourceFuzzy(path) => {
-                                    import_occurrences.push(Occurrences::from_path(&path))
-                                }
-                            },
-                            Import::Alias { .. } => {}
-                        }
+                    Module::SourceExact(path) => import_paths.push(path),
+                    Module::SourceFuzzy(path) => {
+                        import_occurrences.push(Occurrences::from_path(&path))
+                    }
+                }
+                found_external_identifier = Some(&external_identifier);
+            } else {
+                for import in imports {
+                    match import {
+                        Import::Direct { module } => match module {
+                            Module::Namespace(namespace) => {
+                                import_occurrences.push(namespace.occurrences())
+                            }
+                            Module::SourceExact(path) => import_paths.push(path),
+                            Module::SourceFuzzy(path) => {
+                                import_occurrences.push(Occurrences::from_path(&path))
+                            }
+                        },
+                        Import::Alias { .. } => {}
                     }
                 }
             }
+        }
 
-            let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
-            // TODO: update this to be able to return more declarations? Especially if there is the
-            // ability to quickly filter a large list (based on imports)
-            let declarations = index
-                .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(
-                    &identifier_to_lookup,
-                );
-            let declaration_count = declarations.len();
-
-            if declaration_count == 0 {
-                return Vec::new();
-            }
+        let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
+        // TODO: update this to be able to return more declarations? Especially if there is the
+        // ability to quickly filter a large list (based on imports)
+        let identifier_declarations = index
+            .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
+        let declaration_count = identifier_declarations.len();
 
-            // TODO: option to filter out other candidates when same file / import match
-            let mut checked_declarations = Vec::new();
-            for (declaration_id, declaration) in declarations {
-                match declaration {
-                    Declaration::Buffer {
-                        buffer_id,
-                        declaration: buffer_declaration,
-                        ..
-                    } => {
-                        if buffer_id == &current_buffer.remote_id() {
-                            let already_included_in_prompt =
-                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
-                                    .is_some()
-                                    || excerpt.parent_declarations.iter().any(
-                                        |(excerpt_parent, _)| excerpt_parent == &declaration_id,
-                                    );
-                            if !options.omit_excerpt_overlaps || !already_included_in_prompt {
-                                let declaration_line = buffer_declaration
-                                    .item_range
-                                    .start
-                                    .to_point(current_buffer)
-                                    .row;
-                                let declaration_line_distance = (cursor_point.row as i32
-                                    - declaration_line as i32)
-                                    .unsigned_abs();
-                                checked_declarations.push(CheckedDeclaration {
-                                    declaration,
-                                    same_file_line_distance: Some(declaration_line_distance),
-                                    path_import_match_count: 0,
-                                    wildcard_path_import_match_count: 0,
-                                });
-                            }
-                            continue;
-                        } else {
+        if declaration_count == 0 {
+            continue;
+        }
+
+        // TODO: option to filter out other candidates when same file / import match
+        let mut checked_declarations = Vec::with_capacity(declaration_count);
+        for (declaration_id, declaration) in identifier_declarations {
+            match declaration {
+                Declaration::Buffer {
+                    buffer_id,
+                    declaration: buffer_declaration,
+                    ..
+                } => {
+                    if buffer_id == &current_buffer.remote_id() {
+                        let already_included_in_prompt =
+                            range_intersection(&buffer_declaration.item_range, &excerpt.range)
+                                .is_some()
+                                || excerpt
+                                    .parent_declarations
+                                    .iter()
+                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
+                        if !options.omit_excerpt_overlaps || !already_included_in_prompt {
+                            let declaration_line = buffer_declaration
+                                .item_range
+                                .start
+                                .to_point(current_buffer)
+                                .row;
+                            let declaration_line_distance =
+                                (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
+                            checked_declarations.push(CheckedDeclaration {
+                                declaration,
+                                same_file_line_distance: Some(declaration_line_distance),
+                                path_import_match_count: 0,
+                                wildcard_path_import_match_count: 0,
+                            });
                         }
+                        continue;
+                    } else {
                     }
-                    Declaration::File { .. } => {}
                 }
-                let declaration_path = declaration.cached_path();
-                let path_import_match_count = import_paths
-                    .iter()
-                    .filter(|import_path| {
-                        declaration_path_matches_import(&declaration_path, import_path)
-                    })
-                    .count();
-                let wildcard_path_import_match_count = wildcard_import_paths
-                    .iter()
-                    .filter(|import_path| {
-                        declaration_path_matches_import(&declaration_path, import_path)
-                    })
-                    .count();
-                checked_declarations.push(CheckedDeclaration {
-                    declaration,
-                    same_file_line_distance: None,
-                    path_import_match_count,
-                    wildcard_path_import_match_count,
-                });
+                Declaration::File { .. } => {}
             }
+            let declaration_path = declaration.cached_path();
+            let path_import_match_count = import_paths
+                .iter()
+                .filter(|import_path| {
+                    declaration_path_matches_import(&declaration_path, import_path)
+                })
+                .count();
+            let wildcard_path_import_match_count = wildcard_import_paths
+                .iter()
+                .filter(|import_path| {
+                    declaration_path_matches_import(&declaration_path, import_path)
+                })
+                .count();
+            checked_declarations.push(CheckedDeclaration {
+                declaration,
+                same_file_line_distance: None,
+                path_import_match_count,
+                wildcard_path_import_match_count,
+            });
+        }
 
-            let mut max_import_similarity = 0.0;
-            let mut max_wildcard_import_similarity = 0.0;
-
-            let mut scored_declarations_for_identifier = checked_declarations
-                .into_iter()
-                .map(|checked_declaration| {
-                    let same_file_declaration_count =
-                        index.file_declaration_count(checked_declaration.declaration);
-
-                    let declaration = score_declaration(
-                        &identifier,
-                        &references,
-                        checked_declaration,
-                        same_file_declaration_count,
-                        declaration_count,
-                        &excerpt_occurrences,
-                        &adjacent_occurrences,
-                        &import_occurrences,
-                        &wildcard_import_occurrences,
-                        cursor_point,
-                        current_buffer,
-                    );
-
-                    if declaration.components.import_similarity > max_import_similarity {
-                        max_import_similarity = declaration.components.import_similarity;
-                    }
+        let mut max_import_similarity = 0.0;
+        let mut max_wildcard_import_similarity = 0.0;
+
+        let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
+        for checked_declaration in checked_declarations {
+            let same_file_declaration_count =
+                index.file_declaration_count(checked_declaration.declaration);
+
+            let declaration = score_declaration(
+                &identifier,
+                &references,
+                checked_declaration,
+                same_file_declaration_count,
+                declaration_count,
+                &excerpt_occurrences,
+                &adjacent_occurrences,
+                &import_occurrences,
+                &wildcard_import_occurrences,
+                cursor_point,
+                current_buffer,
+            );
+
+            if declaration.components.import_similarity > max_import_similarity {
+                max_import_similarity = declaration.components.import_similarity;
+            }
 
-                    if declaration.components.wildcard_import_similarity
-                        > max_wildcard_import_similarity
-                    {
-                        max_wildcard_import_similarity =
-                            declaration.components.wildcard_import_similarity;
-                    }
+            if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
+                max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
+            }
 
-                    declaration
-                })
-                .collect::<Vec<_>>();
-
-            if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
-                for declaration in scored_declarations_for_identifier.iter_mut() {
-                    if max_import_similarity > 0.0 {
-                        declaration.components.max_import_similarity = max_import_similarity;
-                        declaration.components.normalized_import_similarity =
-                            declaration.components.import_similarity / max_import_similarity;
-                    }
-                    if max_wildcard_import_similarity > 0.0 {
-                        declaration.components.normalized_wildcard_import_similarity =
-                            declaration.components.wildcard_import_similarity
-                                / max_wildcard_import_similarity;
-                    }
+            project_entry_id_to_outline_ranges
+                .entry(declaration.declaration.project_entry_id())
+                .or_default()
+                .push(declaration.declaration.item_range());
+            scored_declarations_for_identifier.push(declaration);
+        }
+
+        if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
+            for declaration in scored_declarations_for_identifier.iter_mut() {
+                if max_import_similarity > 0.0 {
+                    declaration.components.max_import_similarity = max_import_similarity;
+                    declaration.components.normalized_import_similarity =
+                        declaration.components.import_similarity / max_import_similarity;
+                }
+                if max_wildcard_import_similarity > 0.0 {
+                    declaration.components.normalized_wildcard_import_similarity =
+                        declaration.components.wildcard_import_similarity
+                            / max_wildcard_import_similarity;
                 }
             }
+        }
 
-            scored_declarations_for_identifier
-        })
-        .collect::<Vec<_>>();
+        scored_declarations.extend(scored_declarations_for_identifier);
+    }
+
+    // TODO: Inform this via import / retrieval scores of outline items
+    // TODO: Consider using a sweepline instead of comparing every pair of ranges in a file
+    for scored_declaration in scored_declarations.iter_mut() {
+        let project_entry_id = scored_declaration.declaration.project_entry_id();
+        let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
+            continue;
+        };
+        for range in ranges {
+            if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
+                scored_declaration.components.included_by_others += 1
+            } else if scored_declaration
+                .declaration
+                .item_range()
+                .contains_inclusive(range)
+            {
+                scored_declaration.components.includes_others += 1
+            }
+        }
+    }
 
-    declarations.sort_unstable_by_key(|declaration| {
-        let score_density = declaration
-            .score_density(DeclarationStyle::Declaration)
-            .max(declaration.score_density(DeclarationStyle::Signature));
-        Reverse(OrderedFloat(score_density))
+    scored_declarations.sort_unstable_by_key(|declaration| {
+        Reverse(OrderedFloat(
+            declaration.score(DeclarationStyle::Declaration),
+        ))
     });
 
-    declarations
+    scored_declarations
 }
 
 struct CheckedDeclaration<'a> {
@@ -465,6 +486,8 @@ fn score_declaration(
         normalized_import_similarity: 0.0,
         wildcard_import_similarity,
         normalized_wildcard_import_similarity: 0.0,
+        included_by_others: 0,
+        includes_others: 0,
     };
 
     ScoredDeclaration {