declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
  2use collections::HashMap;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use project::ProjectEntryId;
  6use serde::Serialize;
  7use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
  8use strum::EnumIter;
  9use text::{Point, ToPoint};
 10use util::RangeExt as _;
 11
 12use crate::{
 13    declaration::{CachedDeclarationPath, Declaration, Identifier},
 14    excerpt::EditPredictionExcerpt,
 15    imports::{Import, Imports, Module},
 16    reference::{Reference, ReferenceRegion},
 17    syntax_index::SyntaxIndexState,
 18    text_similarity::{
 19        IdentifierParts, OccurrenceSource, Occurrences, Similarity as _, WeightedSimilarity as _,
 20    },
 21};
 22
 23const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 24
 25#[derive(Clone, Debug, PartialEq, Eq)]
 26pub struct EditPredictionScoreOptions {
 27    pub omit_excerpt_overlaps: bool,
 28}
 29
 30#[derive(Clone, Debug)]
 31pub struct ScoredDeclaration {
 32    /// identifier used by the local reference
 33    pub identifier: Identifier,
 34    pub declaration: Declaration,
 35    pub components: DeclarationScoreComponents,
 36}
 37
 38#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 39pub enum DeclarationStyle {
 40    Signature,
 41    Declaration,
 42}
 43
 44#[derive(Clone, Debug, Serialize, Default)]
 45pub struct DeclarationScores {
 46    pub signature: f32,
 47    pub declaration: f32,
 48    pub retrieval: f32,
 49}
 50
 51impl ScoredDeclaration {
 52    /// Returns the score for this declaration with the specified style.
 53    pub fn score(&self, style: DeclarationStyle) -> f32 {
 54        // TODO: handle truncation
 55
 56        // Score related to how likely this is the correct declaration, range 0 to 1
 57        let retrieval = self.retrieval_score();
 58
 59        // Score related to the distance between the reference and cursor, range 0 to 1
 60        let distance_score = if self.components.is_referenced_nearby {
 61            1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
 62        } else {
 63            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
 64            0.5
 65        };
 66
 67        // For now instead of linear combination, the scores are just multiplied together.
 68        let combined_score = 10.0 * retrieval * distance_score;
 69
 70        match style {
 71            DeclarationStyle::Signature => {
 72                combined_score * self.components.excerpt_vs_signature_weighted_overlap
 73            }
 74            DeclarationStyle::Declaration => {
 75                2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
 76            }
 77        }
 78    }
 79
 80    pub fn retrieval_score(&self) -> f32 {
 81        let mut score = if self.components.is_same_file {
 82            10.0 / self.components.same_file_declaration_count as f32
 83        } else if self.components.path_import_match_count > 0 {
 84            3.0
 85        } else if self.components.wildcard_path_import_match_count > 0 {
 86            1.0
 87        } else if self.components.normalized_import_similarity > 0.0 {
 88            self.components.normalized_import_similarity
 89        } else if self.components.normalized_wildcard_import_similarity > 0.0 {
 90            0.5 * self.components.normalized_wildcard_import_similarity
 91        } else {
 92            1.0 / self.components.declaration_count as f32
 93        };
 94        score *= 1. + self.components.included_by_others as f32 / 2.;
 95        score *= 1. + self.components.includes_others as f32 / 4.;
 96        score
 97    }
 98
 99    pub fn size(&self, style: DeclarationStyle) -> usize {
100        match &self.declaration {
101            Declaration::File { declaration, .. } => match style {
102                DeclarationStyle::Signature => declaration.signature_range.len(),
103                DeclarationStyle::Declaration => declaration.text.len(),
104            },
105            Declaration::Buffer { declaration, .. } => match style {
106                DeclarationStyle::Signature => declaration.signature_range.len(),
107                DeclarationStyle::Declaration => declaration.item_range.len(),
108            },
109        }
110    }
111
112    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
113        self.score(style) / self.size(style) as f32
114    }
115}
116
117pub fn scored_declarations(
118    options: &EditPredictionScoreOptions,
119    index: &SyntaxIndexState,
120    excerpt: &EditPredictionExcerpt,
121    excerpt_occurrences: &Occurrences<IdentifierParts>,
122    adjacent_occurrences: &Occurrences<IdentifierParts>,
123    imports: &Imports,
124    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
125    cursor_offset: usize,
126    current_buffer: &BufferSnapshot,
127) -> Vec<ScoredDeclaration> {
128    let cursor_point = cursor_offset.to_point(&current_buffer);
129
130    let mut wildcard_import_occurrences = Vec::new();
131    let mut wildcard_import_paths = Vec::new();
132    for wildcard_import in imports.wildcard_modules.iter() {
133        match wildcard_import {
134            Module::Namespace(namespace) => {
135                wildcard_import_occurrences.push(namespace.occurrences())
136            }
137            Module::SourceExact(path) => wildcard_import_paths.push(path),
138            Module::SourceFuzzy(path) => wildcard_import_occurrences.push(Occurrences::new(
139                IdentifierParts::occurrences_in_path(&path),
140            )),
141        }
142    }
143
144    let mut scored_declarations = Vec::new();
145    let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
146        HashMap::default();
147    for (identifier, references) in identifier_to_references {
148        let mut import_occurrences = Vec::new();
149        let mut import_paths = Vec::new();
150        let mut found_external_identifier: Option<&Identifier> = None;
151
152        if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
153            // only use alias when it's the only import, could be generalized if some language
154            // has overlapping aliases
155            //
156            // TODO: when an aliased declaration is included in the prompt, should include the
157            // aliasing in the prompt.
158            //
159            // TODO: For SourceFuzzy consider having componentwise comparison that pays
160            // attention to ordering.
161            if let [
162                Import::Alias {
163                    module,
164                    external_identifier,
165                },
166            ] = imports.as_slice()
167            {
168                match module {
169                    Module::Namespace(namespace) => {
170                        import_occurrences.push(namespace.occurrences())
171                    }
172                    Module::SourceExact(path) => import_paths.push(path),
173                    Module::SourceFuzzy(path) => import_occurrences.push(Occurrences::new(
174                        IdentifierParts::occurrences_in_path(&path),
175                    )),
176                }
177                found_external_identifier = Some(&external_identifier);
178            } else {
179                for import in imports {
180                    match import {
181                        Import::Direct { module } => match module {
182                            Module::Namespace(namespace) => {
183                                import_occurrences.push(namespace.occurrences())
184                            }
185                            Module::SourceExact(path) => import_paths.push(path),
186                            Module::SourceFuzzy(path) => import_occurrences.push(Occurrences::new(
187                                IdentifierParts::occurrences_in_path(&path),
188                            )),
189                        },
190                        Import::Alias { .. } => {}
191                    }
192                }
193            }
194        }
195
196        let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
197        // TODO: update this to be able to return more declarations? Especially if there is the
198        // ability to quickly filter a large list (based on imports)
199        let identifier_declarations = index
200            .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
201        let declaration_count = identifier_declarations.len();
202
203        if declaration_count == 0 {
204            continue;
205        }
206
207        // TODO: option to filter out other candidates when same file / import match
208        let mut checked_declarations = Vec::with_capacity(declaration_count);
209        for (declaration_id, declaration) in identifier_declarations {
210            match declaration {
211                Declaration::Buffer {
212                    buffer_id,
213                    declaration: buffer_declaration,
214                    ..
215                } => {
216                    if buffer_id == &current_buffer.remote_id() {
217                        let already_included_in_prompt =
218                            range_intersection(&buffer_declaration.item_range, &excerpt.range)
219                                .is_some()
220                                || excerpt
221                                    .parent_declarations
222                                    .iter()
223                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
224                        if !options.omit_excerpt_overlaps || !already_included_in_prompt {
225                            let declaration_line = buffer_declaration
226                                .item_range
227                                .start
228                                .to_point(current_buffer)
229                                .row;
230                            let declaration_line_distance =
231                                (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
232                            checked_declarations.push(CheckedDeclaration {
233                                declaration,
234                                same_file_line_distance: Some(declaration_line_distance),
235                                path_import_match_count: 0,
236                                wildcard_path_import_match_count: 0,
237                            });
238                        }
239                        continue;
240                    } else {
241                    }
242                }
243                Declaration::File { .. } => {}
244            }
245            let declaration_path = declaration.cached_path();
246            let path_import_match_count = import_paths
247                .iter()
248                .filter(|import_path| {
249                    declaration_path_matches_import(&declaration_path, import_path)
250                })
251                .count();
252            let wildcard_path_import_match_count = wildcard_import_paths
253                .iter()
254                .filter(|import_path| {
255                    declaration_path_matches_import(&declaration_path, import_path)
256                })
257                .count();
258            checked_declarations.push(CheckedDeclaration {
259                declaration,
260                same_file_line_distance: None,
261                path_import_match_count,
262                wildcard_path_import_match_count,
263            });
264        }
265
266        let mut max_import_similarity = 0.0;
267        let mut max_wildcard_import_similarity = 0.0;
268
269        let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
270        for checked_declaration in checked_declarations {
271            let same_file_declaration_count =
272                index.file_declaration_count(checked_declaration.declaration);
273
274            let declaration = score_declaration(
275                &identifier,
276                &references,
277                checked_declaration,
278                same_file_declaration_count,
279                declaration_count,
280                &excerpt_occurrences,
281                &adjacent_occurrences,
282                &import_occurrences,
283                &wildcard_import_occurrences,
284                cursor_point,
285                current_buffer,
286            );
287
288            if declaration.components.import_similarity > max_import_similarity {
289                max_import_similarity = declaration.components.import_similarity;
290            }
291
292            if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
293                max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
294            }
295
296            project_entry_id_to_outline_ranges
297                .entry(declaration.declaration.project_entry_id())
298                .or_default()
299                .push(declaration.declaration.item_range());
300            scored_declarations_for_identifier.push(declaration);
301        }
302
303        if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
304            for declaration in scored_declarations_for_identifier.iter_mut() {
305                if max_import_similarity > 0.0 {
306                    declaration.components.max_import_similarity = max_import_similarity;
307                    declaration.components.normalized_import_similarity =
308                        declaration.components.import_similarity / max_import_similarity;
309                }
310                if max_wildcard_import_similarity > 0.0 {
311                    declaration.components.normalized_wildcard_import_similarity =
312                        declaration.components.wildcard_import_similarity
313                            / max_wildcard_import_similarity;
314                }
315            }
316        }
317
318        scored_declarations.extend(scored_declarations_for_identifier);
319    }
320
321    // TODO: Inform this via import / retrieval scores of outline items
322    // TODO: Consider using a sweepline
323    for scored_declaration in scored_declarations.iter_mut() {
324        let project_entry_id = scored_declaration.declaration.project_entry_id();
325        let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
326            continue;
327        };
328        for range in ranges {
329            if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
330                scored_declaration.components.included_by_others += 1
331            } else if scored_declaration
332                .declaration
333                .item_range()
334                .contains_inclusive(range)
335            {
336                scored_declaration.components.includes_others += 1
337            }
338        }
339    }
340
341    scored_declarations.sort_unstable_by_key(|declaration| {
342        Reverse(OrderedFloat(
343            declaration.score(DeclarationStyle::Declaration),
344        ))
345    });
346
347    scored_declarations
348}
349
350struct CheckedDeclaration<'a> {
351    declaration: &'a Declaration,
352    same_file_line_distance: Option<u32>,
353    path_import_match_count: usize,
354    wildcard_path_import_match_count: usize,
355}
356
357fn declaration_path_matches_import(
358    declaration_path: &CachedDeclarationPath,
359    import_path: &Arc<Path>,
360) -> bool {
361    if import_path.is_absolute() {
362        declaration_path.equals_absolute_path(import_path)
363    } else {
364        declaration_path.ends_with_posix_path(import_path)
365    }
366}
367
/// Returns the overlap of two half-open ranges, or `None` when they do not overlap.
///
/// Ranges that merely touch (e.g. `0..3` and `3..5`) have an empty overlap and yield `None`.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lo = std::cmp::max(a.start.clone(), b.start.clone());
    let hi = std::cmp::min(a.end.clone(), b.end.clone());
    (lo < hi).then(|| lo..hi)
}
377
378fn score_declaration(
379    identifier: &Identifier,
380    references: &[Reference],
381    checked_declaration: CheckedDeclaration,
382    same_file_declaration_count: usize,
383    declaration_count: usize,
384    excerpt_occurrences: &Occurrences<IdentifierParts>,
385    adjacent_occurrences: &Occurrences<IdentifierParts>,
386    import_occurrences: &[Occurrences<IdentifierParts>],
387    wildcard_import_occurrences: &[Occurrences<IdentifierParts>],
388    cursor: Point,
389    current_buffer: &BufferSnapshot,
390) -> ScoredDeclaration {
391    let CheckedDeclaration {
392        declaration,
393        same_file_line_distance,
394        path_import_match_count,
395        wildcard_path_import_match_count,
396    } = checked_declaration;
397
398    let is_referenced_nearby = references
399        .iter()
400        .any(|r| r.region == ReferenceRegion::Nearby);
401    let is_referenced_in_breadcrumb = references
402        .iter()
403        .any(|r| r.region == ReferenceRegion::Breadcrumb);
404    let reference_count = references.len();
405    let reference_line_distance = references
406        .iter()
407        .map(|r| {
408            let reference_line = r.range.start.to_point(current_buffer).row as i32;
409            (cursor.row as i32 - reference_line).unsigned_abs()
410        })
411        .min()
412        .unwrap();
413
414    let is_same_file = same_file_line_distance.is_some();
415    let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
416
417    let item_source_occurrences = Occurrences::new(IdentifierParts::occurrences_in_str(
418        &declaration.item_text().0,
419    ));
420    let item_signature_occurrences = Occurrences::new(IdentifierParts::occurrences_in_str(
421        &declaration.signature_text().0,
422    ));
423    let excerpt_vs_item_jaccard = excerpt_occurrences.jaccard_similarity(&item_source_occurrences);
424    let excerpt_vs_signature_jaccard =
425        excerpt_occurrences.jaccard_similarity(&item_signature_occurrences);
426    let adjacent_vs_item_jaccard =
427        adjacent_occurrences.jaccard_similarity(&item_source_occurrences);
428    let adjacent_vs_signature_jaccard =
429        adjacent_occurrences.jaccard_similarity(&item_signature_occurrences);
430
431    let excerpt_vs_item_weighted_overlap =
432        excerpt_occurrences.weighted_overlap_coefficient(&item_source_occurrences);
433    let excerpt_vs_signature_weighted_overlap =
434        excerpt_occurrences.weighted_overlap_coefficient(&item_signature_occurrences);
435    let adjacent_vs_item_weighted_overlap =
436        adjacent_occurrences.weighted_overlap_coefficient(&item_source_occurrences);
437    let adjacent_vs_signature_weighted_overlap =
438        adjacent_occurrences.weighted_overlap_coefficient(&item_signature_occurrences);
439
440    let mut import_similarity = 0f32;
441    let mut wildcard_import_similarity = 0f32;
442    if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
443        let cached_path = declaration.cached_path();
444        let path_occurrences = Occurrences::new(IdentifierParts::occurrences_in_worktree_path(
445            cached_path
446                .worktree_abs_path
447                .file_name()
448                .map(|f| f.to_string_lossy()),
449            &cached_path.rel_path,
450        ));
451        import_similarity = import_occurrences
452            .iter()
453            .map(|namespace_occurrences| {
454                OrderedFloat(namespace_occurrences.jaccard_similarity(&path_occurrences))
455            })
456            .max()
457            .map(|similarity| similarity.into_inner())
458            .unwrap_or_default();
459
460        // TODO: Consider something other than max
461        wildcard_import_similarity = wildcard_import_occurrences
462            .iter()
463            .map(|namespace_occurrences| {
464                OrderedFloat(namespace_occurrences.jaccard_similarity(&path_occurrences))
465            })
466            .max()
467            .map(|similarity| similarity.into_inner())
468            .unwrap_or_default();
469    }
470
471    // TODO: Consider adding declaration_file_count
472    let score_components = DeclarationScoreComponents {
473        is_same_file,
474        is_referenced_nearby,
475        is_referenced_in_breadcrumb,
476        reference_line_distance,
477        declaration_line_distance,
478        reference_count,
479        same_file_declaration_count,
480        declaration_count,
481        excerpt_vs_item_jaccard,
482        excerpt_vs_signature_jaccard,
483        adjacent_vs_item_jaccard,
484        adjacent_vs_signature_jaccard,
485        excerpt_vs_item_weighted_overlap,
486        excerpt_vs_signature_weighted_overlap,
487        adjacent_vs_item_weighted_overlap,
488        adjacent_vs_signature_weighted_overlap,
489        path_import_match_count,
490        wildcard_path_import_match_count,
491        import_similarity,
492        max_import_similarity: 0.0,
493        normalized_import_similarity: 0.0,
494        wildcard_import_similarity,
495        normalized_wildcard_import_similarity: 0.0,
496        included_by_others: 0,
497        includes_others: 0,
498    };
499
500    ScoredDeclaration {
501        identifier: identifier.clone(),
502        declaration: declaration.clone(),
503        components: score_components,
504    }
505}
506
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");

        // (import path, whether it should resolve to `declaration_path`)
        let cases = [
            ("maths.ts", true),
            ("project/src/maths.ts", true),
            ("user/project/src/maths.ts", true),
            ("/home/user/project/src/maths.ts", true),
            ("other.ts", false),
            ("/home/user/project/src/other.ts", false),
        ];

        for &(raw_import_path, should_match) in &cases {
            let import_path: Arc<Path> = Path::new(raw_import_path).into();
            assert_eq!(
                declaration_path_matches_import(&declaration_path, &import_path),
                should_match
            );
        }
    }
}