diff --git a/crates/edit_prediction_context/src/declaration_scoring.rs b/crates/edit_prediction_context/src/declaration_scoring.rs index 48a823362769770c836b44e7d8a6c1942d3a1196..511f053263fb116cb5e5d9ca92cbe9b81da82cc7 100644 --- a/crates/edit_prediction_context/src/declaration_scoring.rs +++ b/crates/edit_prediction_context/src/declaration_scoring.rs @@ -19,9 +19,10 @@ use crate::{ const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, PartialEq)] pub struct EditPredictionScoreOptions { pub omit_excerpt_overlaps: bool, + pub prefilter_score_ratio: f32, } #[derive(Clone, Debug)] @@ -262,6 +263,8 @@ pub fn scored_declarations( let mut max_import_similarity = 0.0; let mut max_wildcard_import_similarity = 0.0; + // todo! consider max retrieval score instead? + let mut max_score = 0.0; let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len()); for checked_declaration in checked_declarations { @@ -294,11 +297,19 @@ pub fn scored_declarations( .entry(declaration.declaration.project_entry_id()) .or_default() .push(declaration.declaration.item_range()); + let score = declaration.score(DeclarationStyle::Declaration); scored_declarations_for_identifier.push(declaration); + + if score > max_score { + max_score = score; + } } - if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 { - for declaration in scored_declarations_for_identifier.iter_mut() { + if max_import_similarity > 0.0 + || max_wildcard_import_similarity > 0.0 + || options.prefilter_score_ratio > 0.0 + { + for mut declaration in scored_declarations_for_identifier.into_iter() { if max_import_similarity > 0.0 { declaration.components.max_import_similarity = max_import_similarity; declaration.components.normalized_import_similarity = @@ -309,10 +320,16 @@ pub fn scored_declarations( declaration.components.wildcard_import_similarity / max_wildcard_import_similarity; } + if options.prefilter_score_ratio <= 0.0 + || declaration.score(DeclarationStyle::Declaration) + > max_score * options.prefilter_score_ratio + { + scored_declarations.push(declaration); + } } + } else { + scored_declarations.extend(scored_declarations_for_identifier); } - - scored_declarations.extend(scored_declarations_for_identifier); } // TODO: Inform this via import / retrieval scores of outline items diff --git a/crates/edit_prediction_context/src/edit_prediction_context.rs b/crates/edit_prediction_context/src/edit_prediction_context.rs index 19cafe0412bb0db67ef906d1ff119d7c23234f78..2d7e675019f3fa46251a8e82cb0ef1bd024d9493 100644 --- a/crates/edit_prediction_context/src/edit_prediction_context.rs +++ b/crates/edit_prediction_context/src/edit_prediction_context.rs @@ -201,6 +201,7 @@ mod tests { }, score: EditPredictionScoreOptions { omit_excerpt_overlaps: true, + prefilter_score_ratio: 0.0, }, }, Some(index.clone()), diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index e4789aa085a27ca11e443c84f487b9f7c2c82538..52bbab15173e48b6a51ac34f625c4636406d6113 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -52,6 +52,7 @@ pub const DEFAULT_CONTEXT_OPTIONS: EditPredictionContextOptions = EditPrediction }, score: EditPredictionScoreOptions { omit_excerpt_overlaps: true, + prefilter_score_ratio: 0.5, }, }; diff --git a/crates/zeta_cli/src/main.rs b/crates/zeta_cli/src/main.rs index 236c5eb4572cf451a3efd435b9d0ad20d4380b72..1e3425122481df86fa4019d5b73b46aa2ccaf59c 100644 --- a/crates/zeta_cli/src/main.rs +++ b/crates/zeta_cli/src/main.rs @@ -112,6 +112,8 @@ struct Zeta2Args { file_indexing_parallelism: usize, #[arg(long, default_value_t = false)] disable_imports_gathering: bool, + #[arg(long, default_value_t = 0.5)] + prefilter_score_ratio: f32, } #[derive(clap::ValueEnum, Default, Debug, Clone)] @@ -389,6 +391,7 @@ impl Zeta2Args { }, score: EditPredictionScoreOptions { omit_excerpt_overlaps, + prefilter_score_ratio: self.prefilter_score_ratio, }, }, max_diagnostic_bytes: self.max_diagnostic_bytes,