declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
  2use collections::HashMap;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use project::ProjectEntryId;
  6use serde::Serialize;
  7use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
  8use strum::EnumIter;
  9use text::{Point, ToPoint};
 10use util::RangeExt as _;
 11
 12use crate::{
 13    CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
 14    imports::{Import, Imports, Module},
 15    reference::{Reference, ReferenceRegion},
 16    syntax_index::SyntaxIndexState,
 17    text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
 18};
 19
/// Upper bound on how many candidate declarations are fetched from the index for a single
/// identifier (passed as the const generic to `declarations_for_identifier`).
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 21
/// Options controlling how candidate declarations are filtered and scored by
/// `scored_declarations`.
#[derive(Clone, Debug, PartialEq)]
pub struct EditPredictionScoreOptions {
    /// When true, declarations in the current buffer whose item range intersects the excerpt
    /// range, or which are one of the excerpt's parent declarations, are not scored at all.
    pub omit_excerpt_overlaps: bool,
    /// When greater than 0, a candidate is kept only if its `Declaration`-style score exceeds
    /// this ratio times the maximum score computed among the candidates for the same
    /// identifier. Values `<= 0` disable this prefilter.
    pub prefilter_score_ratio: f32,
}
 27
/// A candidate declaration for an identifier referenced near the cursor, together with the
/// raw components used to score it.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// identifier used by the local reference
    pub identifier: Identifier,
    /// The candidate declaration. When the local identifier is an import alias, this was
    /// looked up by the aliased (external) identifier, so its name can differ from
    /// `identifier`.
    pub declaration: Declaration,
    /// Raw scoring features; combined into a single number by `ScoredDeclaration::score`.
    pub components: DeclarationScoreComponents,
}
 35
/// Which portion of a declaration is considered; both scoring (`ScoredDeclaration::score`)
/// and size (`ScoredDeclaration::size`) depend on it.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the declaration's signature (its `signature_range`).
    Signature,
    /// The whole declaration item.
    Declaration,
}
 41
/// Per-style scores for a declaration plus its retrieval score.
///
/// NOTE(review): not constructed in this file; presumably filled from
/// `ScoredDeclaration::score` (per `DeclarationStyle`) and `retrieval_score` elsewhere —
/// confirm against callers. Field order determines the serialized key order.
#[derive(Clone, Debug, Serialize, Default)]
pub struct DeclarationScores {
    pub signature: f32,
    pub declaration: f32,
    pub retrieval: f32,
}
 48
 49impl ScoredDeclaration {
 50    /// Returns the score for this declaration with the specified style.
 51    pub fn score(&self, style: DeclarationStyle) -> f32 {
 52        // TODO: handle truncation
 53
 54        // Score related to how likely this is the correct declaration, range 0 to 1
 55        let retrieval = self.retrieval_score();
 56
 57        // Score related to the distance between the reference and cursor, range 0 to 1
 58        let distance_score = if self.components.is_referenced_nearby {
 59            1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
 60        } else {
 61            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
 62            0.5
 63        };
 64
 65        // For now instead of linear combination, the scores are just multiplied together.
 66        let combined_score = 10.0 * retrieval * distance_score;
 67
 68        match style {
 69            DeclarationStyle::Signature => {
 70                combined_score * self.components.excerpt_vs_signature_weighted_overlap
 71            }
 72            DeclarationStyle::Declaration => {
 73                2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
 74            }
 75        }
 76    }
 77
 78    pub fn retrieval_score(&self) -> f32 {
 79        let mut score = if self.components.is_same_file {
 80            10.0 / self.components.same_file_declaration_count as f32
 81        } else if self.components.path_import_match_count > 0 {
 82            3.0
 83        } else if self.components.wildcard_path_import_match_count > 0 {
 84            1.0
 85        } else if self.components.normalized_import_similarity > 0.0 {
 86            self.components.normalized_import_similarity
 87        } else if self.components.normalized_wildcard_import_similarity > 0.0 {
 88            0.5 * self.components.normalized_wildcard_import_similarity
 89        } else {
 90            1.0 / self.components.declaration_count as f32
 91        };
 92        score *= 1. + self.components.included_by_others as f32 / 2.;
 93        score *= 1. + self.components.includes_others as f32 / 4.;
 94        score
 95    }
 96
 97    pub fn size(&self, style: DeclarationStyle) -> usize {
 98        match &self.declaration {
 99            Declaration::File { declaration, .. } => match style {
100                DeclarationStyle::Signature => declaration.signature_range.len(),
101                DeclarationStyle::Declaration => declaration.text.len(),
102            },
103            Declaration::Buffer { declaration, .. } => match style {
104                DeclarationStyle::Signature => declaration.signature_range.len(),
105                DeclarationStyle::Declaration => declaration.item_range.len(),
106            },
107        }
108    }
109
110    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
111        self.score(style) / self.size(style) as f32
112    }
113}
114
115pub fn scored_declarations(
116    options: &EditPredictionScoreOptions,
117    index: &SyntaxIndexState,
118    excerpt: &EditPredictionExcerpt,
119    excerpt_occurrences: &Occurrences,
120    adjacent_occurrences: &Occurrences,
121    imports: &Imports,
122    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
123    cursor_offset: usize,
124    current_buffer: &BufferSnapshot,
125) -> Vec<ScoredDeclaration> {
126    let cursor_point = cursor_offset.to_point(&current_buffer);
127
128    let mut wildcard_import_occurrences = Vec::new();
129    let mut wildcard_import_paths = Vec::new();
130    for wildcard_import in imports.wildcard_modules.iter() {
131        match wildcard_import {
132            Module::Namespace(namespace) => {
133                wildcard_import_occurrences.push(namespace.occurrences())
134            }
135            Module::SourceExact(path) => wildcard_import_paths.push(path),
136            Module::SourceFuzzy(path) => {
137                wildcard_import_occurrences.push(Occurrences::from_path(&path))
138            }
139        }
140    }
141
142    let mut scored_declarations = Vec::new();
143    let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
144        HashMap::default();
145    for (identifier, references) in identifier_to_references {
146        let mut import_occurrences = Vec::new();
147        let mut import_paths = Vec::new();
148        let mut found_external_identifier: Option<&Identifier> = None;
149
150        if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
151            // only use alias when it's the only import, could be generalized if some language
152            // has overlapping aliases
153            //
154            // TODO: when an aliased declaration is included in the prompt, should include the
155            // aliasing in the prompt.
156            //
157            // TODO: For SourceFuzzy consider having componentwise comparison that pays
158            // attention to ordering.
159            if let [
160                Import::Alias {
161                    module,
162                    external_identifier,
163                },
164            ] = imports.as_slice()
165            {
166                match module {
167                    Module::Namespace(namespace) => {
168                        import_occurrences.push(namespace.occurrences())
169                    }
170                    Module::SourceExact(path) => import_paths.push(path),
171                    Module::SourceFuzzy(path) => {
172                        import_occurrences.push(Occurrences::from_path(&path))
173                    }
174                }
175                found_external_identifier = Some(&external_identifier);
176            } else {
177                for import in imports {
178                    match import {
179                        Import::Direct { module } => match module {
180                            Module::Namespace(namespace) => {
181                                import_occurrences.push(namespace.occurrences())
182                            }
183                            Module::SourceExact(path) => import_paths.push(path),
184                            Module::SourceFuzzy(path) => {
185                                import_occurrences.push(Occurrences::from_path(&path))
186                            }
187                        },
188                        Import::Alias { .. } => {}
189                    }
190                }
191            }
192        }
193
194        let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
195        // TODO: update this to be able to return more declarations? Especially if there is the
196        // ability to quickly filter a large list (based on imports)
197        let identifier_declarations = index
198            .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
199        let declaration_count = identifier_declarations.len();
200
201        if declaration_count == 0 {
202            continue;
203        }
204
205        // TODO: option to filter out other candidates when same file / import match
206        let mut checked_declarations = Vec::with_capacity(declaration_count);
207        for (declaration_id, declaration) in identifier_declarations {
208            match declaration {
209                Declaration::Buffer {
210                    buffer_id,
211                    declaration: buffer_declaration,
212                    ..
213                } => {
214                    if buffer_id == &current_buffer.remote_id() {
215                        let already_included_in_prompt =
216                            range_intersection(&buffer_declaration.item_range, &excerpt.range)
217                                .is_some()
218                                || excerpt
219                                    .parent_declarations
220                                    .iter()
221                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
222                        if !options.omit_excerpt_overlaps || !already_included_in_prompt {
223                            let declaration_line = buffer_declaration
224                                .item_range
225                                .start
226                                .to_point(current_buffer)
227                                .row;
228                            let declaration_line_distance =
229                                (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
230                            checked_declarations.push(CheckedDeclaration {
231                                declaration,
232                                same_file_line_distance: Some(declaration_line_distance),
233                                path_import_match_count: 0,
234                                wildcard_path_import_match_count: 0,
235                            });
236                        }
237                        continue;
238                    } else {
239                    }
240                }
241                Declaration::File { .. } => {}
242            }
243            let declaration_path = declaration.cached_path();
244            let path_import_match_count = import_paths
245                .iter()
246                .filter(|import_path| {
247                    declaration_path_matches_import(&declaration_path, import_path)
248                })
249                .count();
250            let wildcard_path_import_match_count = wildcard_import_paths
251                .iter()
252                .filter(|import_path| {
253                    declaration_path_matches_import(&declaration_path, import_path)
254                })
255                .count();
256            checked_declarations.push(CheckedDeclaration {
257                declaration,
258                same_file_line_distance: None,
259                path_import_match_count,
260                wildcard_path_import_match_count,
261            });
262        }
263
264        let mut max_import_similarity = 0.0;
265        let mut max_wildcard_import_similarity = 0.0;
266        // todo! consider max retrieval score instead?
267        let mut max_score = 0.0;
268
269        let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
270        for checked_declaration in checked_declarations {
271            let same_file_declaration_count =
272                index.file_declaration_count(checked_declaration.declaration);
273
274            let declaration = score_declaration(
275                &identifier,
276                &references,
277                checked_declaration,
278                same_file_declaration_count,
279                declaration_count,
280                &excerpt_occurrences,
281                &adjacent_occurrences,
282                &import_occurrences,
283                &wildcard_import_occurrences,
284                cursor_point,
285                current_buffer,
286            );
287
288            if declaration.components.import_similarity > max_import_similarity {
289                max_import_similarity = declaration.components.import_similarity;
290            }
291
292            if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
293                max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
294            }
295
296            project_entry_id_to_outline_ranges
297                .entry(declaration.declaration.project_entry_id())
298                .or_default()
299                .push(declaration.declaration.item_range());
300            let score = declaration.score(DeclarationStyle::Declaration);
301            scored_declarations_for_identifier.push(declaration);
302
303            if score > max_score {
304                max_score = score;
305            }
306        }
307
308        if max_import_similarity > 0.0
309            || max_wildcard_import_similarity > 0.0
310            || options.prefilter_score_ratio > 0.0
311        {
312            for mut declaration in scored_declarations_for_identifier.into_iter() {
313                if max_import_similarity > 0.0 {
314                    declaration.components.max_import_similarity = max_import_similarity;
315                    declaration.components.normalized_import_similarity =
316                        declaration.components.import_similarity / max_import_similarity;
317                }
318                if max_wildcard_import_similarity > 0.0 {
319                    declaration.components.normalized_wildcard_import_similarity =
320                        declaration.components.wildcard_import_similarity
321                            / max_wildcard_import_similarity;
322                }
323                if options.prefilter_score_ratio <= 0.0
324                    || declaration.score(DeclarationStyle::Declaration)
325                        > max_score * options.prefilter_score_ratio
326                {
327                    scored_declarations.push(declaration);
328                }
329            }
330        } else {
331            scored_declarations.extend(scored_declarations_for_identifier);
332        }
333    }
334
335    // TODO: Inform this via import / retrieval scores of outline items
336    // TODO: Consider using a sweepline
337    for scored_declaration in scored_declarations.iter_mut() {
338        let project_entry_id = scored_declaration.declaration.project_entry_id();
339        let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
340            continue;
341        };
342        for range in ranges {
343            if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
344                scored_declaration.components.included_by_others += 1
345            } else if scored_declaration
346                .declaration
347                .item_range()
348                .contains_inclusive(range)
349            {
350                scored_declaration.components.includes_others += 1
351            }
352        }
353    }
354
355    scored_declarations.sort_unstable_by_key(|declaration| {
356        Reverse(OrderedFloat(
357            declaration.score(DeclarationStyle::Declaration),
358        ))
359    });
360
361    scored_declarations
362}
363
/// A candidate declaration that survived the initial filtering in `scored_declarations`,
/// annotated with the cheap location/path signals computed before full scoring.
struct CheckedDeclaration<'a> {
    declaration: &'a Declaration,
    /// Row distance between the cursor and the start of the declaration's item range; `Some`
    /// only for declarations in the current buffer.
    same_file_line_distance: Option<u32>,
    /// Number of exact import paths (from this identifier's imports) matching the
    /// declaration's path; 0 for current-buffer declarations.
    path_import_match_count: usize,
    /// Number of wildcard import paths matching the declaration's path; 0 for current-buffer
    /// declarations.
    wildcard_path_import_match_count: usize,
}
370
371fn declaration_path_matches_import(
372    declaration_path: &CachedDeclarationPath,
373    import_path: &Arc<Path>,
374) -> bool {
375    if import_path.is_absolute() {
376        declaration_path.equals_absolute_path(import_path)
377    } else {
378        declaration_path.ends_with_posix_path(import_path)
379    }
380}
381
/// Returns the overlap of two half-open ranges, or `None` when they do not overlap (ranges
/// that merely touch, or empty results, yield `None`).
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then(|| lower..upper)
}
391
392fn score_declaration(
393    identifier: &Identifier,
394    references: &[Reference],
395    checked_declaration: CheckedDeclaration,
396    same_file_declaration_count: usize,
397    declaration_count: usize,
398    excerpt_occurrences: &Occurrences,
399    adjacent_occurrences: &Occurrences,
400    import_occurrences: &[Occurrences],
401    wildcard_import_occurrences: &[Occurrences],
402    cursor: Point,
403    current_buffer: &BufferSnapshot,
404) -> ScoredDeclaration {
405    let CheckedDeclaration {
406        declaration,
407        same_file_line_distance,
408        path_import_match_count,
409        wildcard_path_import_match_count,
410    } = checked_declaration;
411
412    let is_referenced_nearby = references
413        .iter()
414        .any(|r| r.region == ReferenceRegion::Nearby);
415    let is_referenced_in_breadcrumb = references
416        .iter()
417        .any(|r| r.region == ReferenceRegion::Breadcrumb);
418    let reference_count = references.len();
419    let reference_line_distance = references
420        .iter()
421        .map(|r| {
422            let reference_line = r.range.start.to_point(current_buffer).row as i32;
423            (cursor.row as i32 - reference_line).unsigned_abs()
424        })
425        .min()
426        .unwrap();
427
428    let is_same_file = same_file_line_distance.is_some();
429    let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
430
431    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
432    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
433    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
434    let excerpt_vs_signature_jaccard =
435        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
436    let adjacent_vs_item_jaccard =
437        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
438    let adjacent_vs_signature_jaccard =
439        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
440
441    let excerpt_vs_item_weighted_overlap =
442        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
443    let excerpt_vs_signature_weighted_overlap =
444        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
445    let adjacent_vs_item_weighted_overlap =
446        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
447    let adjacent_vs_signature_weighted_overlap =
448        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
449
450    let mut import_similarity = 0f32;
451    let mut wildcard_import_similarity = 0f32;
452    if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
453        let cached_path = declaration.cached_path();
454        let path_occurrences = Occurrences::from_worktree_path(
455            cached_path
456                .worktree_abs_path
457                .file_name()
458                .map(|f| f.to_string_lossy()),
459            &cached_path.rel_path,
460        );
461        import_similarity = import_occurrences
462            .iter()
463            .map(|namespace_occurrences| {
464                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
465            })
466            .max()
467            .map(|similarity| similarity.into_inner())
468            .unwrap_or_default();
469
470        // TODO: Consider something other than max
471        wildcard_import_similarity = wildcard_import_occurrences
472            .iter()
473            .map(|namespace_occurrences| {
474                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
475            })
476            .max()
477            .map(|similarity| similarity.into_inner())
478            .unwrap_or_default();
479    }
480
481    // TODO: Consider adding declaration_file_count
482    let score_components = DeclarationScoreComponents {
483        is_same_file,
484        is_referenced_nearby,
485        is_referenced_in_breadcrumb,
486        reference_line_distance,
487        declaration_line_distance,
488        reference_count,
489        same_file_declaration_count,
490        declaration_count,
491        excerpt_vs_item_jaccard,
492        excerpt_vs_signature_jaccard,
493        adjacent_vs_item_jaccard,
494        adjacent_vs_signature_jaccard,
495        excerpt_vs_item_weighted_overlap,
496        excerpt_vs_signature_weighted_overlap,
497        adjacent_vs_item_weighted_overlap,
498        adjacent_vs_signature_weighted_overlap,
499        path_import_match_count,
500        wildcard_path_import_match_count,
501        import_similarity,
502        max_import_similarity: 0.0,
503        normalized_import_similarity: 0.0,
504        wildcard_import_similarity,
505        normalized_wildcard_import_similarity: 0.0,
506        included_by_others: 0,
507        includes_others: 0,
508    };
509
510    ScoredDeclaration {
511        identifier: identifier.clone(),
512        declaration: declaration.clone(),
513        components: score_components,
514    }
515}
516
#[cfg(test)]
mod test {
    use super::*;

    /// Relative imports match by path suffix; absolute imports must match the full path.
    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");

        let matching_imports = [
            "maths.ts",
            "project/src/maths.ts",
            "user/project/src/maths.ts",
            "/home/user/project/src/maths.ts",
        ];
        for import in matching_imports {
            assert!(
                declaration_path_matches_import(&declaration_path, &Path::new(import).into()),
                "expected {import:?} to match"
            );
        }

        let non_matching_imports = ["other.ts", "/home/user/project/src/other.ts"];
        for import in non_matching_imports {
            assert!(
                !declaration_path_matches_import(&declaration_path, &Path::new(import).into()),
                "expected {import:?} not to match"
            );
        }
    }
}