declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
  2use collections::HashMap;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use serde::Serialize;
  6use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
  7use strum::EnumIter;
  8use text::{Point, ToPoint};
  9
 10use crate::{
 11    CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
 12    imports::{Import, Imports, Module},
 13    reference::{Reference, ReferenceRegion},
 14    syntax_index::SyntaxIndexState,
 15    text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
 16};
 17
/// Const generic argument passed to `SyntaxIndexState::declarations_for_identifier` when looking
/// up the declarations of a single identifier in `scored_declarations`.
///
/// NOTE(review): presumably this caps how many declarations the index returns per identifier —
/// confirm against `declarations_for_identifier`.
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 19
/// Options controlling which candidate declarations `scored_declarations` keeps.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EditPredictionScoreOptions {
    /// When `true`, declarations in the current buffer are skipped if their item range
    /// intersects the excerpt range or if they are one of the excerpt's parent declarations.
    pub omit_excerpt_overlaps: bool,
}
 24
/// A candidate declaration paired with the measurements used to score it.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// identifier used by the local reference
    pub identifier: Identifier,
    /// The candidate declaration that was scored.
    pub declaration: Declaration,
    /// Raw score components; `score` and `retrieval_score` are derived from these.
    pub components: DeclarationScoreComponents,
}
 32
/// Selects which part of a declaration is considered when computing its score and size.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the declaration's signature (sized by its `signature_range`).
    Signature,
    /// The whole declaration item (sized by its text / `item_range`).
    Declaration,
}
 38
/// Serializable bundle of scores for a single declaration.
///
/// NOTE(review): this type is not constructed in the visible portion of this file. Presumably
/// `signature` and `declaration` hold `ScoredDeclaration::score` for the corresponding
/// `DeclarationStyle`, and `retrieval` holds `ScoredDeclaration::retrieval_score` — confirm
/// against the code that builds it.
#[derive(Clone, Debug, Serialize, Default)]
pub struct DeclarationScores {
    pub signature: f32,
    pub declaration: f32,
    pub retrieval: f32,
}
 45
 46impl ScoredDeclaration {
 47    /// Returns the score for this declaration with the specified style.
 48    pub fn score(&self, style: DeclarationStyle) -> f32 {
 49        // TODO: handle truncation
 50
 51        // Score related to how likely this is the correct declaration, range 0 to 1
 52        let retrieval = self.retrieval_score();
 53
 54        // Score related to the distance between the reference and cursor, range 0 to 1
 55        let distance_score = if self.components.is_referenced_nearby {
 56            1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
 57        } else {
 58            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
 59            0.5
 60        };
 61
 62        // For now instead of linear combination, the scores are just multiplied together.
 63        let combined_score = 10.0 * retrieval * distance_score;
 64
 65        match style {
 66            DeclarationStyle::Signature => {
 67                combined_score * self.components.excerpt_vs_signature_weighted_overlap
 68            }
 69            DeclarationStyle::Declaration => {
 70                2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
 71            }
 72        }
 73    }
 74
 75    pub fn retrieval_score(&self) -> f32 {
 76        if self.components.is_same_file {
 77            10.0 / self.components.same_file_declaration_count as f32
 78        } else if self.components.path_import_match_count > 0 {
 79            3.0
 80        } else if self.components.wildcard_path_import_match_count > 0 {
 81            1.0
 82        } else if self.components.normalized_import_similarity > 0.0 {
 83            self.components.normalized_import_similarity
 84        } else if self.components.normalized_wildcard_import_similarity > 0.0 {
 85            0.5 * self.components.normalized_wildcard_import_similarity
 86        } else {
 87            1.0 / self.components.declaration_count as f32
 88        }
 89    }
 90
 91    pub fn size(&self, style: DeclarationStyle) -> usize {
 92        match &self.declaration {
 93            Declaration::File { declaration, .. } => match style {
 94                DeclarationStyle::Signature => declaration.signature_range.len(),
 95                DeclarationStyle::Declaration => declaration.text.len(),
 96            },
 97            Declaration::Buffer { declaration, .. } => match style {
 98                DeclarationStyle::Signature => declaration.signature_range.len(),
 99                DeclarationStyle::Declaration => declaration.item_range.len(),
100            },
101        }
102    }
103
104    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
105        self.score(style) / self.size(style) as f32
106    }
107}
108
109pub fn scored_declarations(
110    options: &EditPredictionScoreOptions,
111    index: &SyntaxIndexState,
112    excerpt: &EditPredictionExcerpt,
113    excerpt_occurrences: &Occurrences,
114    adjacent_occurrences: &Occurrences,
115    imports: &Imports,
116    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
117    cursor_offset: usize,
118    current_buffer: &BufferSnapshot,
119) -> Vec<ScoredDeclaration> {
120    let cursor_point = cursor_offset.to_point(&current_buffer);
121
122    let mut wildcard_import_occurrences = Vec::new();
123    let mut wildcard_import_paths = Vec::new();
124    for wildcard_import in imports.wildcard_modules.iter() {
125        match wildcard_import {
126            Module::Namespace(namespace) => {
127                wildcard_import_occurrences.push(namespace.occurrences())
128            }
129            Module::SourceExact(path) => wildcard_import_paths.push(path),
130            Module::SourceFuzzy(path) => {
131                wildcard_import_occurrences.push(Occurrences::from_path(&path))
132            }
133        }
134    }
135
136    let mut declarations = identifier_to_references
137        .into_iter()
138        .flat_map(|(identifier, references)| {
139            let mut import_occurrences = Vec::new();
140            let mut import_paths = Vec::new();
141            let mut found_external_identifier: Option<&Identifier> = None;
142
143            if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
144                // only use alias when it's the only import, could be generalized if some language
145                // has overlapping aliases
146                //
147                // TODO: when an aliased declaration is included in the prompt, should include the
148                // aliasing in the prompt.
149                //
150                // TODO: For SourceFuzzy consider having componentwise comparison that pays
151                // attention to ordering.
152                if let [
153                    Import::Alias {
154                        module,
155                        external_identifier,
156                    },
157                ] = imports.as_slice()
158                {
159                    match module {
160                        Module::Namespace(namespace) => {
161                            import_occurrences.push(namespace.occurrences())
162                        }
163                        Module::SourceExact(path) => import_paths.push(path),
164                        Module::SourceFuzzy(path) => {
165                            import_occurrences.push(Occurrences::from_path(&path))
166                        }
167                    }
168                    found_external_identifier = Some(&external_identifier);
169                } else {
170                    for import in imports {
171                        match import {
172                            Import::Direct { module } => match module {
173                                Module::Namespace(namespace) => {
174                                    import_occurrences.push(namespace.occurrences())
175                                }
176                                Module::SourceExact(path) => import_paths.push(path),
177                                Module::SourceFuzzy(path) => {
178                                    import_occurrences.push(Occurrences::from_path(&path))
179                                }
180                            },
181                            Import::Alias { .. } => {}
182                        }
183                    }
184                }
185            }
186
187            let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
188            // TODO: update this to be able to return more declarations? Especially if there is the
189            // ability to quickly filter a large list (based on imports)
190            let declarations = index
191                .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(
192                    &identifier_to_lookup,
193                );
194            let declaration_count = declarations.len();
195
196            if declaration_count == 0 {
197                return Vec::new();
198            }
199
200            // TODO: option to filter out other candidates when same file / import match
201            let mut checked_declarations = Vec::new();
202            for (declaration_id, declaration) in declarations {
203                match declaration {
204                    Declaration::Buffer {
205                        buffer_id,
206                        declaration: buffer_declaration,
207                        ..
208                    } => {
209                        if buffer_id == &current_buffer.remote_id() {
210                            let already_included_in_prompt =
211                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
212                                    .is_some()
213                                    || excerpt.parent_declarations.iter().any(
214                                        |(excerpt_parent, _)| excerpt_parent == &declaration_id,
215                                    );
216                            if !options.omit_excerpt_overlaps || !already_included_in_prompt {
217                                let declaration_line = buffer_declaration
218                                    .item_range
219                                    .start
220                                    .to_point(current_buffer)
221                                    .row;
222                                let declaration_line_distance = (cursor_point.row as i32
223                                    - declaration_line as i32)
224                                    .unsigned_abs();
225                                checked_declarations.push(CheckedDeclaration {
226                                    declaration,
227                                    same_file_line_distance: Some(declaration_line_distance),
228                                    path_import_match_count: 0,
229                                    wildcard_path_import_match_count: 0,
230                                });
231                            }
232                            continue;
233                        } else {
234                        }
235                    }
236                    Declaration::File { .. } => {}
237                }
238                let declaration_path = declaration.cached_path();
239                let path_import_match_count = import_paths
240                    .iter()
241                    .filter(|import_path| {
242                        declaration_path_matches_import(&declaration_path, import_path)
243                    })
244                    .count();
245                let wildcard_path_import_match_count = wildcard_import_paths
246                    .iter()
247                    .filter(|import_path| {
248                        declaration_path_matches_import(&declaration_path, import_path)
249                    })
250                    .count();
251                checked_declarations.push(CheckedDeclaration {
252                    declaration,
253                    same_file_line_distance: None,
254                    path_import_match_count,
255                    wildcard_path_import_match_count,
256                });
257            }
258
259            let mut max_import_similarity = 0.0;
260            let mut max_wildcard_import_similarity = 0.0;
261
262            let mut scored_declarations_for_identifier = checked_declarations
263                .into_iter()
264                .map(|checked_declaration| {
265                    let same_file_declaration_count =
266                        index.file_declaration_count(checked_declaration.declaration);
267
268                    let declaration = score_declaration(
269                        &identifier,
270                        &references,
271                        checked_declaration,
272                        same_file_declaration_count,
273                        declaration_count,
274                        &excerpt_occurrences,
275                        &adjacent_occurrences,
276                        &import_occurrences,
277                        &wildcard_import_occurrences,
278                        cursor_point,
279                        current_buffer,
280                    );
281
282                    if declaration.components.import_similarity > max_import_similarity {
283                        max_import_similarity = declaration.components.import_similarity;
284                    }
285
286                    if declaration.components.wildcard_import_similarity
287                        > max_wildcard_import_similarity
288                    {
289                        max_wildcard_import_similarity =
290                            declaration.components.wildcard_import_similarity;
291                    }
292
293                    declaration
294                })
295                .collect::<Vec<_>>();
296
297            if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
298                for declaration in scored_declarations_for_identifier.iter_mut() {
299                    if max_import_similarity > 0.0 {
300                        declaration.components.max_import_similarity = max_import_similarity;
301                        declaration.components.normalized_import_similarity =
302                            declaration.components.import_similarity / max_import_similarity;
303                    }
304                    if max_wildcard_import_similarity > 0.0 {
305                        declaration.components.normalized_wildcard_import_similarity =
306                            declaration.components.wildcard_import_similarity
307                                / max_wildcard_import_similarity;
308                    }
309                }
310            }
311
312            scored_declarations_for_identifier
313        })
314        .collect::<Vec<_>>();
315
316    declarations.sort_unstable_by_key(|declaration| {
317        let score_density = declaration
318            .score_density(DeclarationStyle::Declaration)
319            .max(declaration.score_density(DeclarationStyle::Signature));
320        Reverse(OrderedFloat(score_density))
321    });
322
323    declarations
324}
325
/// A candidate declaration that passed filtering in `scored_declarations`, annotated with the
/// location and import-match data that `score_declaration` needs.
struct CheckedDeclaration<'a> {
    /// The candidate declaration.
    declaration: &'a Declaration,
    /// `Some` line distance between the cursor and the start of the declaration when the
    /// declaration is in the current buffer; `None` for every other declaration.
    same_file_line_distance: Option<u32>,
    /// Number of the identifier's exact import paths that match the declaration's path.
    /// Set to 0 for declarations in the current buffer.
    path_import_match_count: usize,
    /// Number of exact wildcard import paths that match the declaration's path.
    /// Set to 0 for declarations in the current buffer.
    wildcard_path_import_match_count: usize,
}
332
333fn declaration_path_matches_import(
334    declaration_path: &CachedDeclarationPath,
335    import_path: &Arc<Path>,
336) -> bool {
337    if import_path.is_absolute() {
338        declaration_path.equals_absolute_path(import_path)
339    } else {
340        declaration_path.ends_with_posix_path(import_path)
341    }
342}
343
/// Returns the overlapping part of `a` and `b`, or `None` when they do not overlap.
///
/// Ranges that merely touch (one ends exactly where the other starts) and empty ranges produce
/// `None`, because the resulting intersection would be empty.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    // Compare by reference so only the two chosen bounds are cloned.
    let start = Ord::max(&a.start, &b.start).clone();
    let end = Ord::min(&a.end, &b.end).clone();
    (start < end).then(|| start..end)
}
353
354fn score_declaration(
355    identifier: &Identifier,
356    references: &[Reference],
357    checked_declaration: CheckedDeclaration,
358    same_file_declaration_count: usize,
359    declaration_count: usize,
360    excerpt_occurrences: &Occurrences,
361    adjacent_occurrences: &Occurrences,
362    import_occurrences: &[Occurrences],
363    wildcard_import_occurrences: &[Occurrences],
364    cursor: Point,
365    current_buffer: &BufferSnapshot,
366) -> ScoredDeclaration {
367    let CheckedDeclaration {
368        declaration,
369        same_file_line_distance,
370        path_import_match_count,
371        wildcard_path_import_match_count,
372    } = checked_declaration;
373
374    let is_referenced_nearby = references
375        .iter()
376        .any(|r| r.region == ReferenceRegion::Nearby);
377    let is_referenced_in_breadcrumb = references
378        .iter()
379        .any(|r| r.region == ReferenceRegion::Breadcrumb);
380    let reference_count = references.len();
381    let reference_line_distance = references
382        .iter()
383        .map(|r| {
384            let reference_line = r.range.start.to_point(current_buffer).row as i32;
385            (cursor.row as i32 - reference_line).unsigned_abs()
386        })
387        .min()
388        .unwrap();
389
390    let is_same_file = same_file_line_distance.is_some();
391    let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
392
393    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
394    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
395    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
396    let excerpt_vs_signature_jaccard =
397        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
398    let adjacent_vs_item_jaccard =
399        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
400    let adjacent_vs_signature_jaccard =
401        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
402
403    let excerpt_vs_item_weighted_overlap =
404        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
405    let excerpt_vs_signature_weighted_overlap =
406        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
407    let adjacent_vs_item_weighted_overlap =
408        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
409    let adjacent_vs_signature_weighted_overlap =
410        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
411
412    let mut import_similarity = 0f32;
413    let mut wildcard_import_similarity = 0f32;
414    if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
415        let cached_path = declaration.cached_path();
416        let path_occurrences = Occurrences::from_worktree_path(
417            cached_path
418                .worktree_abs_path
419                .file_name()
420                .map(|f| f.to_string_lossy()),
421            &cached_path.rel_path,
422        );
423        import_similarity = import_occurrences
424            .iter()
425            .map(|namespace_occurrences| {
426                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
427            })
428            .max()
429            .map(|similarity| similarity.into_inner())
430            .unwrap_or_default();
431
432        // TODO: Consider something other than max
433        wildcard_import_similarity = wildcard_import_occurrences
434            .iter()
435            .map(|namespace_occurrences| {
436                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
437            })
438            .max()
439            .map(|similarity| similarity.into_inner())
440            .unwrap_or_default();
441    }
442
443    // TODO: Consider adding declaration_file_count
444    let score_components = DeclarationScoreComponents {
445        is_same_file,
446        is_referenced_nearby,
447        is_referenced_in_breadcrumb,
448        reference_line_distance,
449        declaration_line_distance,
450        reference_count,
451        same_file_declaration_count,
452        declaration_count,
453        excerpt_vs_item_jaccard,
454        excerpt_vs_signature_jaccard,
455        adjacent_vs_item_jaccard,
456        adjacent_vs_signature_jaccard,
457        excerpt_vs_item_weighted_overlap,
458        excerpt_vs_signature_weighted_overlap,
459        adjacent_vs_item_weighted_overlap,
460        adjacent_vs_signature_weighted_overlap,
461        path_import_match_count,
462        wildcard_path_import_match_count,
463        import_similarity,
464        max_import_similarity: 0.0,
465        normalized_import_similarity: 0.0,
466        wildcard_import_similarity,
467        normalized_wildcard_import_similarity: 0.0,
468    };
469
470    ScoredDeclaration {
471        identifier: identifier.clone(),
472        declaration: declaration.clone(),
473        components: score_components,
474    }
475}
476
#[cfg(test)]
mod test {
    use super::*;

    /// Checks relative-suffix and absolute import paths against a declaration located at
    /// `/home/user/project` + `src/maths.ts`.
    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
        let matches = |import_path: &str| {
            declaration_path_matches_import(&declaration_path, &Path::new(import_path).into())
        };

        // Relative paths match when they are a suffix of the declaration's full path.
        assert!(matches("maths.ts"));
        assert!(matches("project/src/maths.ts"));
        assert!(matches("user/project/src/maths.ts"));

        // Absolute paths must be equal to the declaration's full path.
        assert!(matches("/home/user/project/src/maths.ts"));

        // Different file names never match, relative or absolute.
        assert!(!matches("other.ts"));
        assert!(!matches("/home/user/project/src/other.ts"));
    }
}