declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
  2use itertools::Itertools as _;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use serde::Serialize;
  6use std::{cmp::Reverse, collections::HashMap, ops::Range};
  7use strum::EnumIter;
  8use text::{Point, ToPoint};
  9
 10use crate::{
 11    Declaration, EditPredictionExcerpt, Identifier,
 12    reference::{Reference, ReferenceRegion},
 13    syntax_index::SyntaxIndexState,
 14    text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
 15};
 16
 17const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 18
 19#[derive(Clone, Debug)]
 20pub struct ScoredDeclaration {
 21    pub identifier: Identifier,
 22    pub declaration: Declaration,
 23    pub score_components: DeclarationScoreComponents,
 24    pub scores: DeclarationScores,
 25}
 26
 27#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 28pub enum DeclarationStyle {
 29    Signature,
 30    Declaration,
 31}
 32
 33impl ScoredDeclaration {
 34    /// Returns the score for this declaration with the specified style.
 35    pub fn score(&self, style: DeclarationStyle) -> f32 {
 36        match style {
 37            DeclarationStyle::Signature => self.scores.signature,
 38            DeclarationStyle::Declaration => self.scores.declaration,
 39        }
 40    }
 41
 42    pub fn size(&self, style: DeclarationStyle) -> usize {
 43        match &self.declaration {
 44            Declaration::File { declaration, .. } => match style {
 45                DeclarationStyle::Signature => declaration.signature_range.len(),
 46                DeclarationStyle::Declaration => declaration.text.len(),
 47            },
 48            Declaration::Buffer { declaration, .. } => match style {
 49                DeclarationStyle::Signature => declaration.signature_range.len(),
 50                DeclarationStyle::Declaration => declaration.item_range.len(),
 51            },
 52        }
 53    }
 54
 55    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
 56        self.score(style) / (self.size(style)) as f32
 57    }
 58}
 59
 60pub fn scored_declarations(
 61    index: &SyntaxIndexState,
 62    excerpt: &EditPredictionExcerpt,
 63    excerpt_occurrences: &Occurrences,
 64    adjacent_occurrences: &Occurrences,
 65    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
 66    cursor_offset: usize,
 67    current_buffer: &BufferSnapshot,
 68) -> Vec<ScoredDeclaration> {
 69    let cursor_point = cursor_offset.to_point(&current_buffer);
 70
 71    let mut declarations = identifier_to_references
 72        .into_iter()
 73        .flat_map(|(identifier, references)| {
 74            let declarations =
 75                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
 76            let declaration_count = declarations.len();
 77
 78            declarations
 79                .into_iter()
 80                .filter_map(|(declaration_id, declaration)| match declaration {
 81                    Declaration::Buffer {
 82                        buffer_id,
 83                        declaration: buffer_declaration,
 84                        ..
 85                    } => {
 86                        let is_same_file = buffer_id == &current_buffer.remote_id();
 87
 88                        if is_same_file {
 89                            let overlaps_excerpt =
 90                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
 91                                    .is_some();
 92                            if overlaps_excerpt
 93                                || excerpt
 94                                    .parent_declarations
 95                                    .iter()
 96                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
 97                            {
 98                                None
 99                            } else {
100                                let declaration_line = buffer_declaration
101                                    .item_range
102                                    .start
103                                    .to_point(current_buffer)
104                                    .row;
105                                Some((
106                                    true,
107                                    (cursor_point.row as i32 - declaration_line as i32)
108                                        .unsigned_abs(),
109                                    declaration,
110                                ))
111                            }
112                        } else {
113                            Some((false, u32::MAX, declaration))
114                        }
115                    }
116                    Declaration::File { .. } => {
117                        // We can assume that a file declaration is in a different file,
118                        // because the current one must be open
119                        Some((false, u32::MAX, declaration))
120                    }
121                })
122                .sorted_by_key(|&(_, distance, _)| distance)
123                .enumerate()
124                .map(
125                    |(
126                        declaration_line_distance_rank,
127                        (is_same_file, declaration_line_distance, declaration),
128                    )| {
129                        let same_file_declaration_count = index.file_declaration_count(declaration);
130
131                        score_declaration(
132                            &identifier,
133                            &references,
134                            declaration.clone(),
135                            is_same_file,
136                            declaration_line_distance,
137                            declaration_line_distance_rank,
138                            same_file_declaration_count,
139                            declaration_count,
140                            &excerpt_occurrences,
141                            &adjacent_occurrences,
142                            cursor_point,
143                            current_buffer,
144                        )
145                    },
146                )
147                .collect::<Vec<_>>()
148        })
149        .flatten()
150        .collect::<Vec<_>>();
151
152    declarations.sort_unstable_by_key(|declaration| {
153        let score_density = declaration
154            .score_density(DeclarationStyle::Declaration)
155            .max(declaration.score_density(DeclarationStyle::Signature));
156        Reverse(OrderedFloat(score_density))
157    });
158
159    declarations
160}
161
/// Returns the overlap of two half-open ranges, or `None` if they share no elements.
///
/// Ranges that only touch (e.g. `0..3` and `3..6`) or that are empty do not intersect.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then_some(lower..upper)
}
171
172fn score_declaration(
173    identifier: &Identifier,
174    references: &[Reference],
175    declaration: Declaration,
176    is_same_file: bool,
177    declaration_line_distance: u32,
178    declaration_line_distance_rank: usize,
179    same_file_declaration_count: usize,
180    declaration_count: usize,
181    excerpt_occurrences: &Occurrences,
182    adjacent_occurrences: &Occurrences,
183    cursor: Point,
184    current_buffer: &BufferSnapshot,
185) -> Option<ScoredDeclaration> {
186    let is_referenced_nearby = references
187        .iter()
188        .any(|r| r.region == ReferenceRegion::Nearby);
189    let is_referenced_in_breadcrumb = references
190        .iter()
191        .any(|r| r.region == ReferenceRegion::Breadcrumb);
192    let reference_count = references.len();
193    let reference_line_distance = references
194        .iter()
195        .map(|r| {
196            let reference_line = r.range.start.to_point(current_buffer).row as i32;
197            (cursor.row as i32 - reference_line).unsigned_abs()
198        })
199        .min()
200        .unwrap();
201
202    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
203    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
204    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
205    let excerpt_vs_signature_jaccard =
206        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
207    let adjacent_vs_item_jaccard =
208        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
209    let adjacent_vs_signature_jaccard =
210        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
211
212    let excerpt_vs_item_weighted_overlap =
213        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
214    let excerpt_vs_signature_weighted_overlap =
215        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
216    let adjacent_vs_item_weighted_overlap =
217        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
218    let adjacent_vs_signature_weighted_overlap =
219        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
220
221    // TODO: Consider adding declaration_file_count
222    let score_components = DeclarationScoreComponents {
223        is_same_file,
224        is_referenced_nearby,
225        is_referenced_in_breadcrumb,
226        reference_line_distance,
227        declaration_line_distance,
228        declaration_line_distance_rank,
229        reference_count,
230        same_file_declaration_count,
231        declaration_count,
232        excerpt_vs_item_jaccard,
233        excerpt_vs_signature_jaccard,
234        adjacent_vs_item_jaccard,
235        adjacent_vs_signature_jaccard,
236        excerpt_vs_item_weighted_overlap,
237        excerpt_vs_signature_weighted_overlap,
238        adjacent_vs_item_weighted_overlap,
239        adjacent_vs_signature_weighted_overlap,
240    };
241
242    Some(ScoredDeclaration {
243        identifier: identifier.clone(),
244        declaration: declaration,
245        scores: DeclarationScores::score(&score_components),
246        score_components,
247    })
248}
249
250#[derive(Clone, Debug, Serialize)]
251pub struct DeclarationScores {
252    pub signature: f32,
253    pub declaration: f32,
254}
255
256impl DeclarationScores {
257    fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
258        // TODO: handle truncation
259
260        // Score related to how likely this is the correct declaration, range 0 to 1
261        let accuracy_score = if components.is_same_file {
262            // TODO: use declaration_line_distance_rank
263            1.0 / components.same_file_declaration_count as f32
264        } else {
265            1.0 / components.declaration_count as f32
266        };
267
268        // Score related to the distance between the reference and cursor, range 0 to 1
269        let distance_score = if components.is_referenced_nearby {
270            1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
271        } else {
272            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
273            0.5
274        };
275
276        // For now instead of linear combination, the scores are just multiplied together.
277        let combined_score = 10.0 * accuracy_score * distance_score;
278
279        DeclarationScores {
280            signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
281            // declaration score gets boosted both by being multiplied by 2 and by there being more
282            // weighted overlap.
283            declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
284        }
285    }
286}