1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
  2use collections::HashMap;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use project::ProjectEntryId;
  6use serde::Serialize;
  7use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
  8use strum::EnumIter;
  9use text::{Point, ToPoint};
 10use util::RangeExt as _;
 11
 12use crate::{
 13    CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
 14    imports::{Import, Imports, Module},
 15    reference::{Reference, ReferenceRegion},
 16    syntax_index::SyntaxIndexState,
 17    text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
 18};
 19
 20const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 21
 22#[derive(Clone, Debug, PartialEq, Eq)]
 23pub struct EditPredictionScoreOptions {
 24    pub omit_excerpt_overlaps: bool,
 25}
 26
 27#[derive(Clone, Debug)]
 28pub struct ScoredDeclaration {
 29    /// identifier used by the local reference
 30    pub identifier: Identifier,
 31    pub declaration: Declaration,
 32    pub components: DeclarationScoreComponents,
 33}
 34
 35#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 36pub enum DeclarationStyle {
 37    Signature,
 38    Declaration,
 39}
 40
 41#[derive(Clone, Debug, Serialize, Default)]
 42pub struct DeclarationScores {
 43    pub signature: f32,
 44    pub declaration: f32,
 45    pub retrieval: f32,
 46}
 47
 48impl ScoredDeclaration {
 49    /// Returns the score for this declaration with the specified style.
 50    pub fn score(&self, style: DeclarationStyle) -> f32 {
 51        // TODO: handle truncation
 52
 53        // Score related to how likely this is the correct declaration, range 0 to 1
 54        let retrieval = self.retrieval_score();
 55
 56        // Score related to the distance between the reference and cursor, range 0 to 1
 57        let distance_score = if self.components.is_referenced_nearby {
 58            1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
 59        } else {
 60            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
 61            0.5
 62        };
 63
 64        // For now instead of linear combination, the scores are just multiplied together.
 65        let combined_score = 10.0 * retrieval * distance_score;
 66
 67        match style {
 68            DeclarationStyle::Signature => {
 69                combined_score * self.components.excerpt_vs_signature_weighted_overlap
 70            }
 71            DeclarationStyle::Declaration => {
 72                2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
 73            }
 74        }
 75    }
 76
 77    pub fn retrieval_score(&self) -> f32 {
 78        let mut score = if self.components.is_same_file {
 79            10.0 / self.components.same_file_declaration_count as f32
 80        } else if self.components.path_import_match_count > 0 {
 81            3.0
 82        } else if self.components.wildcard_path_import_match_count > 0 {
 83            1.0
 84        } else if self.components.normalized_import_similarity > 0.0 {
 85            self.components.normalized_import_similarity
 86        } else if self.components.normalized_wildcard_import_similarity > 0.0 {
 87            0.5 * self.components.normalized_wildcard_import_similarity
 88        } else {
 89            1.0 / self.components.declaration_count as f32
 90        };
 91        score *= 1. + self.components.included_by_others as f32 / 2.;
 92        score *= 1. + self.components.includes_others as f32 / 4.;
 93        score
 94    }
 95
 96    pub fn size(&self, style: DeclarationStyle) -> usize {
 97        match &self.declaration {
 98            Declaration::File { declaration, .. } => match style {
 99                DeclarationStyle::Signature => declaration.signature_range.len(),
100                DeclarationStyle::Declaration => declaration.text.len(),
101            },
102            Declaration::Buffer { declaration, .. } => match style {
103                DeclarationStyle::Signature => declaration.signature_range.len(),
104                DeclarationStyle::Declaration => declaration.item_range.len(),
105            },
106        }
107    }
108
109    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
110        self.score(style) / self.size(style) as f32
111    }
112}
113
114pub fn scored_declarations(
115    options: &EditPredictionScoreOptions,
116    index: &SyntaxIndexState,
117    excerpt: &EditPredictionExcerpt,
118    excerpt_occurrences: &Occurrences,
119    adjacent_occurrences: &Occurrences,
120    imports: &Imports,
121    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
122    cursor_offset: usize,
123    current_buffer: &BufferSnapshot,
124) -> Vec<ScoredDeclaration> {
125    let cursor_point = cursor_offset.to_point(¤t_buffer);
126
127    let mut wildcard_import_occurrences = Vec::new();
128    let mut wildcard_import_paths = Vec::new();
129    for wildcard_import in imports.wildcard_modules.iter() {
130        match wildcard_import {
131            Module::Namespace(namespace) => {
132                wildcard_import_occurrences.push(namespace.occurrences())
133            }
134            Module::SourceExact(path) => wildcard_import_paths.push(path),
135            Module::SourceFuzzy(path) => {
136                wildcard_import_occurrences.push(Occurrences::from_path(&path))
137            }
138        }
139    }
140
141    let mut scored_declarations = Vec::new();
142    let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
143        HashMap::default();
144    for (identifier, references) in identifier_to_references {
145        let mut import_occurrences = Vec::new();
146        let mut import_paths = Vec::new();
147        let mut found_external_identifier: Option<&Identifier> = None;
148
149        if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
150            // only use alias when it's the only import, could be generalized if some language
151            // has overlapping aliases
152            //
153            // TODO: when an aliased declaration is included in the prompt, should include the
154            // aliasing in the prompt.
155            //
156            // TODO: For SourceFuzzy consider having componentwise comparison that pays
157            // attention to ordering.
158            if let [
159                Import::Alias {
160                    module,
161                    external_identifier,
162                },
163            ] = imports.as_slice()
164            {
165                match module {
166                    Module::Namespace(namespace) => {
167                        import_occurrences.push(namespace.occurrences())
168                    }
169                    Module::SourceExact(path) => import_paths.push(path),
170                    Module::SourceFuzzy(path) => {
171                        import_occurrences.push(Occurrences::from_path(&path))
172                    }
173                }
174                found_external_identifier = Some(&external_identifier);
175            } else {
176                for import in imports {
177                    match import {
178                        Import::Direct { module } => match module {
179                            Module::Namespace(namespace) => {
180                                import_occurrences.push(namespace.occurrences())
181                            }
182                            Module::SourceExact(path) => import_paths.push(path),
183                            Module::SourceFuzzy(path) => {
184                                import_occurrences.push(Occurrences::from_path(&path))
185                            }
186                        },
187                        Import::Alias { .. } => {}
188                    }
189                }
190            }
191        }
192
193        let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
194        // TODO: update this to be able to return more declarations? Especially if there is the
195        // ability to quickly filter a large list (based on imports)
196        let identifier_declarations = index
197            .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
198        let declaration_count = identifier_declarations.len();
199
200        if declaration_count == 0 {
201            continue;
202        }
203
204        // TODO: option to filter out other candidates when same file / import match
205        let mut checked_declarations = Vec::with_capacity(declaration_count);
206        for (declaration_id, declaration) in identifier_declarations {
207            match declaration {
208                Declaration::Buffer {
209                    buffer_id,
210                    declaration: buffer_declaration,
211                    ..
212                } => {
213                    if buffer_id == ¤t_buffer.remote_id() {
214                        let already_included_in_prompt =
215                            range_intersection(&buffer_declaration.item_range, &excerpt.range)
216                                .is_some()
217                                || excerpt
218                                    .parent_declarations
219                                    .iter()
220                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
221                        if !options.omit_excerpt_overlaps || !already_included_in_prompt {
222                            let declaration_line = buffer_declaration
223                                .item_range
224                                .start
225                                .to_point(current_buffer)
226                                .row;
227                            let declaration_line_distance =
228                                (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
229                            checked_declarations.push(CheckedDeclaration {
230                                declaration,
231                                same_file_line_distance: Some(declaration_line_distance),
232                                path_import_match_count: 0,
233                                wildcard_path_import_match_count: 0,
234                            });
235                        }
236                        continue;
237                    } else {
238                    }
239                }
240                Declaration::File { .. } => {}
241            }
242            let declaration_path = declaration.cached_path();
243            let path_import_match_count = import_paths
244                .iter()
245                .filter(|import_path| {
246                    declaration_path_matches_import(&declaration_path, import_path)
247                })
248                .count();
249            let wildcard_path_import_match_count = wildcard_import_paths
250                .iter()
251                .filter(|import_path| {
252                    declaration_path_matches_import(&declaration_path, import_path)
253                })
254                .count();
255            checked_declarations.push(CheckedDeclaration {
256                declaration,
257                same_file_line_distance: None,
258                path_import_match_count,
259                wildcard_path_import_match_count,
260            });
261        }
262
263        let mut max_import_similarity = 0.0;
264        let mut max_wildcard_import_similarity = 0.0;
265
266        let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
267        for checked_declaration in checked_declarations {
268            let same_file_declaration_count =
269                index.file_declaration_count(checked_declaration.declaration);
270
271            let declaration = score_declaration(
272                &identifier,
273                &references,
274                checked_declaration,
275                same_file_declaration_count,
276                declaration_count,
277                &excerpt_occurrences,
278                &adjacent_occurrences,
279                &import_occurrences,
280                &wildcard_import_occurrences,
281                cursor_point,
282                current_buffer,
283            );
284
285            if declaration.components.import_similarity > max_import_similarity {
286                max_import_similarity = declaration.components.import_similarity;
287            }
288
289            if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
290                max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
291            }
292
293            project_entry_id_to_outline_ranges
294                .entry(declaration.declaration.project_entry_id())
295                .or_default()
296                .push(declaration.declaration.item_range());
297            scored_declarations_for_identifier.push(declaration);
298        }
299
300        if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
301            for declaration in scored_declarations_for_identifier.iter_mut() {
302                if max_import_similarity > 0.0 {
303                    declaration.components.max_import_similarity = max_import_similarity;
304                    declaration.components.normalized_import_similarity =
305                        declaration.components.import_similarity / max_import_similarity;
306                }
307                if max_wildcard_import_similarity > 0.0 {
308                    declaration.components.normalized_wildcard_import_similarity =
309                        declaration.components.wildcard_import_similarity
310                            / max_wildcard_import_similarity;
311                }
312            }
313        }
314
315        scored_declarations.extend(scored_declarations_for_identifier);
316    }
317
318    // TODO: Inform this via import / retrieval scores of outline items
319    // TODO: Consider using a sweepline
320    for scored_declaration in scored_declarations.iter_mut() {
321        let project_entry_id = scored_declaration.declaration.project_entry_id();
322        let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
323            continue;
324        };
325        for range in ranges {
326            if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
327                scored_declaration.components.included_by_others += 1
328            } else if scored_declaration
329                .declaration
330                .item_range()
331                .contains_inclusive(range)
332            {
333                scored_declaration.components.includes_others += 1
334            }
335        }
336    }
337
338    scored_declarations.sort_unstable_by_key(|declaration| {
339        Reverse(OrderedFloat(
340            declaration.score(DeclarationStyle::Declaration),
341        ))
342    });
343
344    scored_declarations
345}
346
347struct CheckedDeclaration<'a> {
348    declaration: &'a Declaration,
349    same_file_line_distance: Option<u32>,
350    path_import_match_count: usize,
351    wildcard_path_import_match_count: usize,
352}
353
354fn declaration_path_matches_import(
355    declaration_path: &CachedDeclarationPath,
356    import_path: &Arc<Path>,
357) -> bool {
358    if import_path.is_absolute() {
359        declaration_path.equals_absolute_path(import_path)
360    } else {
361        declaration_path.ends_with_posix_path(import_path)
362    }
363}
364
/// Returns the overlap of two half-open ranges.
///
/// Returns `None` when the ranges share no element. This includes ranges that only touch at
/// an endpoint and cases where either range is empty.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then_some(lower..upper)
}
374
375fn score_declaration(
376    identifier: &Identifier,
377    references: &[Reference],
378    checked_declaration: CheckedDeclaration,
379    same_file_declaration_count: usize,
380    declaration_count: usize,
381    excerpt_occurrences: &Occurrences,
382    adjacent_occurrences: &Occurrences,
383    import_occurrences: &[Occurrences],
384    wildcard_import_occurrences: &[Occurrences],
385    cursor: Point,
386    current_buffer: &BufferSnapshot,
387) -> ScoredDeclaration {
388    let CheckedDeclaration {
389        declaration,
390        same_file_line_distance,
391        path_import_match_count,
392        wildcard_path_import_match_count,
393    } = checked_declaration;
394
395    let is_referenced_nearby = references
396        .iter()
397        .any(|r| r.region == ReferenceRegion::Nearby);
398    let is_referenced_in_breadcrumb = references
399        .iter()
400        .any(|r| r.region == ReferenceRegion::Breadcrumb);
401    let reference_count = references.len();
402    let reference_line_distance = references
403        .iter()
404        .map(|r| {
405            let reference_line = r.range.start.to_point(current_buffer).row as i32;
406            (cursor.row as i32 - reference_line).unsigned_abs()
407        })
408        .min()
409        .unwrap();
410
411    let is_same_file = same_file_line_distance.is_some();
412    let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
413
414    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
415    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
416    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
417    let excerpt_vs_signature_jaccard =
418        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
419    let adjacent_vs_item_jaccard =
420        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
421    let adjacent_vs_signature_jaccard =
422        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
423
424    let excerpt_vs_item_weighted_overlap =
425        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
426    let excerpt_vs_signature_weighted_overlap =
427        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
428    let adjacent_vs_item_weighted_overlap =
429        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
430    let adjacent_vs_signature_weighted_overlap =
431        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
432
433    let mut import_similarity = 0f32;
434    let mut wildcard_import_similarity = 0f32;
435    if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
436        let cached_path = declaration.cached_path();
437        let path_occurrences = Occurrences::from_worktree_path(
438            cached_path
439                .worktree_abs_path
440                .file_name()
441                .map(|f| f.to_string_lossy()),
442            &cached_path.rel_path,
443        );
444        import_similarity = import_occurrences
445            .iter()
446            .map(|namespace_occurrences| {
447                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
448            })
449            .max()
450            .map(|similarity| similarity.into_inner())
451            .unwrap_or_default();
452
453        // TODO: Consider something other than max
454        wildcard_import_similarity = wildcard_import_occurrences
455            .iter()
456            .map(|namespace_occurrences| {
457                OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
458            })
459            .max()
460            .map(|similarity| similarity.into_inner())
461            .unwrap_or_default();
462    }
463
464    // TODO: Consider adding declaration_file_count
465    let score_components = DeclarationScoreComponents {
466        is_same_file,
467        is_referenced_nearby,
468        is_referenced_in_breadcrumb,
469        reference_line_distance,
470        declaration_line_distance,
471        reference_count,
472        same_file_declaration_count,
473        declaration_count,
474        excerpt_vs_item_jaccard,
475        excerpt_vs_signature_jaccard,
476        adjacent_vs_item_jaccard,
477        adjacent_vs_signature_jaccard,
478        excerpt_vs_item_weighted_overlap,
479        excerpt_vs_signature_weighted_overlap,
480        adjacent_vs_item_weighted_overlap,
481        adjacent_vs_signature_weighted_overlap,
482        path_import_match_count,
483        wildcard_path_import_match_count,
484        import_similarity,
485        max_import_similarity: 0.0,
486        normalized_import_similarity: 0.0,
487        wildcard_import_similarity,
488        normalized_wildcard_import_similarity: 0.0,
489        included_by_others: 0,
490        includes_others: 0,
491    };
492
493    ScoredDeclaration {
494        identifier: identifier.clone(),
495        declaration: declaration.clone(),
496        components: score_components,
497    }
498}
499
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
        let matches = |import_path: &str| {
            declaration_path_matches_import(&declaration_path, &Path::new(import_path).into())
        };

        // Relative imports match when they are a suffix of the declaration's path; absolute
        // imports must name the exact file.
        for import_path in [
            "maths.ts",
            "project/src/maths.ts",
            "user/project/src/maths.ts",
            "/home/user/project/src/maths.ts",
        ] {
            assert!(matches(import_path));
        }

        for import_path in ["other.ts", "/home/user/project/src/other.ts"] {
            assert!(!matches(import_path));
        }
    }
}