Add --prefilter-score-ratio

Michael Sloan and Agus created

Co-authored-by: Agus <agus@zed.dev>

Change summary

crates/edit_prediction_context/src/declaration_scoring.rs     | 27 ++++
crates/edit_prediction_context/src/edit_prediction_context.rs |  1 
crates/zeta2/src/zeta2.rs                                     |  1 
crates/zeta_cli/src/main.rs                                   |  3 
4 files changed, 27 insertions(+), 5 deletions(-)

Detailed changes

crates/edit_prediction_context/src/declaration_scoring.rs 🔗

@@ -19,9 +19,10 @@ use crate::{
 
 const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 
-#[derive(Clone, Debug, PartialEq, Eq)]
+#[derive(Clone, Debug, PartialEq)]
 pub struct EditPredictionScoreOptions {
     pub omit_excerpt_overlaps: bool,
+    pub prefilter_score_ratio: f32,
 }
 
 #[derive(Clone, Debug)]
@@ -262,6 +263,8 @@ pub fn scored_declarations(
 
         let mut max_import_similarity = 0.0;
         let mut max_wildcard_import_similarity = 0.0;
+        // todo! consider max retrieval score instead?
+        let mut max_score = 0.0;
 
         let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
         for checked_declaration in checked_declarations {
@@ -294,11 +297,19 @@ pub fn scored_declarations(
                 .entry(declaration.declaration.project_entry_id())
                 .or_default()
                 .push(declaration.declaration.item_range());
+            let score = declaration.score(DeclarationStyle::Declaration);
             scored_declarations_for_identifier.push(declaration);
+
+            if score > max_score {
+                max_score = score;
+            }
         }
 
-        if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
-            for declaration in scored_declarations_for_identifier.iter_mut() {
+        if max_import_similarity > 0.0
+            || max_wildcard_import_similarity > 0.0
+            || options.prefilter_score_ratio > 0.0
+        {
+            for mut declaration in scored_declarations_for_identifier.into_iter() {
                 if max_import_similarity > 0.0 {
                     declaration.components.max_import_similarity = max_import_similarity;
                     declaration.components.normalized_import_similarity =
@@ -309,10 +320,16 @@ pub fn scored_declarations(
                         declaration.components.wildcard_import_similarity
                             / max_wildcard_import_similarity;
                 }
+                if options.prefilter_score_ratio <= 0.0
+                    || declaration.score(DeclarationStyle::Declaration)
+                        > max_score * options.prefilter_score_ratio
+                {
+                    scored_declarations.push(declaration);
+                }
             }
+        } else {
+            scored_declarations.extend(scored_declarations_for_identifier);
         }
-
-        scored_declarations.extend(scored_declarations_for_identifier);
     }
 
     // TODO: Inform this via import / retrieval scores of outline items

crates/zeta2/src/zeta2.rs 🔗

@@ -52,6 +52,7 @@ pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPrediction
     },
     score: EditPredictionScoreOptions {
         omit_excerpt_overlaps: true,
+        prefilter_score_ratio: 0.5,
     },
 };
 

crates/zeta_cli/src/main.rs 🔗

@@ -112,6 +112,8 @@ struct Zeta2Args {
     file_indexing_parallelism: usize,
     #[arg(long, default_value_t = false)]
     disable_imports_gathering: bool,
+    #[arg(long, default_value_t = 0.5)]
+    prefilter_score_ratio: f32,
 }
 
 #[derive(clap::ValueEnum, Default, Debug, Clone)]
@@ -389,6 +391,7 @@ impl Zeta2Args {
                 },
                 score: EditPredictionScoreOptions {
                     omit_excerpt_overlaps,
+                    prefilter_score_ratio: self.prefilter_score_ratio,
                 },
             },
             max_diagnostic_bytes: self.max_diagnostic_bytes,