declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::ScoreComponents;
  2use itertools::Itertools as _;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use serde::Serialize;
  6use std::{cmp::Reverse, collections::HashMap, ops::Range};
  7use strum::EnumIter;
  8use text::{Point, ToPoint};
  9
 10use crate::{
 11    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
 12    reference::{Reference, ReferenceRegion},
 13    syntax_index::SyntaxIndexState,
 14    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 15};
 16
 17const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 18
 19#[derive(Clone, Debug)]
 20pub struct ScoredSnippet {
 21    pub identifier: Identifier,
 22    pub declaration: Declaration,
 23    pub score_components: ScoreComponents,
 24    pub scores: Scores,
 25}
 26
 27#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 28pub enum SnippetStyle {
 29    Signature,
 30    Declaration,
 31}
 32
 33impl ScoredSnippet {
 34    /// Returns the score for this snippet with the specified style.
 35    pub fn score(&self, style: SnippetStyle) -> f32 {
 36        match style {
 37            SnippetStyle::Signature => self.scores.signature,
 38            SnippetStyle::Declaration => self.scores.declaration,
 39        }
 40    }
 41
 42    pub fn size(&self, style: SnippetStyle) -> usize {
 43        match &self.declaration {
 44            Declaration::File { declaration, .. } => match style {
 45                SnippetStyle::Signature => declaration.signature_range.len(),
 46                SnippetStyle::Declaration => declaration.text.len(),
 47            },
 48            Declaration::Buffer { declaration, .. } => match style {
 49                SnippetStyle::Signature => declaration.signature_range.len(),
 50                SnippetStyle::Declaration => declaration.item_range.len(),
 51            },
 52        }
 53    }
 54
 55    pub fn score_density(&self, style: SnippetStyle) -> f32 {
 56        self.score(style) / (self.size(style)) as f32
 57    }
 58}
 59
 60pub fn scored_snippets(
 61    index: &SyntaxIndexState,
 62    excerpt: &EditPredictionExcerpt,
 63    excerpt_text: &EditPredictionExcerptText,
 64    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
 65    cursor_offset: usize,
 66    current_buffer: &BufferSnapshot,
 67) -> Vec<ScoredSnippet> {
 68    let containing_range_identifier_occurrences =
 69        IdentifierOccurrences::within_string(&excerpt_text.body);
 70    let cursor_point = cursor_offset.to_point(&current_buffer);
 71
 72    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
 73    let end_point = Point::new(cursor_point.row + 1, 0);
 74    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
 75        &current_buffer
 76            .text_for_range(start_point..end_point)
 77            .collect::<String>(),
 78    );
 79
 80    let mut snippets = identifier_to_references
 81        .into_iter()
 82        .flat_map(|(identifier, references)| {
 83            let declarations =
 84                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
 85            let declaration_count = declarations.len();
 86
 87            declarations
 88                .into_iter()
 89                .filter_map(|(declaration_id, declaration)| match declaration {
 90                    Declaration::Buffer {
 91                        buffer_id,
 92                        declaration: buffer_declaration,
 93                        ..
 94                    } => {
 95                        let is_same_file = buffer_id == &current_buffer.remote_id();
 96
 97                        if is_same_file {
 98                            let overlaps_excerpt =
 99                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
100                                    .is_some();
101                            if overlaps_excerpt
102                                || excerpt
103                                    .parent_declarations
104                                    .iter()
105                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
106                            {
107                                None
108                            } else {
109                                let declaration_line = buffer_declaration
110                                    .item_range
111                                    .start
112                                    .to_point(current_buffer)
113                                    .row;
114                                Some((
115                                    true,
116                                    (cursor_point.row as i32 - declaration_line as i32)
117                                        .unsigned_abs(),
118                                    declaration,
119                                ))
120                            }
121                        } else {
122                            Some((false, u32::MAX, declaration))
123                        }
124                    }
125                    Declaration::File { .. } => {
126                        // We can assume that a file declaration is in a different file,
127                        // because the current one must be open
128                        Some((false, u32::MAX, declaration))
129                    }
130                })
131                .sorted_by_key(|&(_, distance, _)| distance)
132                .enumerate()
133                .map(
134                    |(
135                        declaration_line_distance_rank,
136                        (is_same_file, declaration_line_distance, declaration),
137                    )| {
138                        let same_file_declaration_count = index.file_declaration_count(declaration);
139
140                        score_snippet(
141                            &identifier,
142                            &references,
143                            declaration.clone(),
144                            is_same_file,
145                            declaration_line_distance,
146                            declaration_line_distance_rank,
147                            same_file_declaration_count,
148                            declaration_count,
149                            &containing_range_identifier_occurrences,
150                            &adjacent_identifier_occurrences,
151                            cursor_point,
152                            current_buffer,
153                        )
154                    },
155                )
156                .collect::<Vec<_>>()
157        })
158        .flatten()
159        .collect::<Vec<_>>();
160
161    snippets.sort_unstable_by_key(|snippet| {
162        let score_density = snippet
163            .score_density(SnippetStyle::Declaration)
164            .max(snippet.score_density(SnippetStyle::Signature));
165        Reverse(OrderedFloat(score_density))
166    });
167
168    snippets
169}
170
/// Returns the overlap of `a` and `b`, or `None` when they share no elements.
///
/// Ranges that merely touch (one ends where the other starts) and empty ranges
/// produce `None`, because the resulting range would be empty.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    // Compare the bounds by reference and clone only the two that are kept.
    let lower = Ord::max(&a.start, &b.start);
    let upper = Ord::min(&a.end, &b.end);
    (lower < upper).then(|| lower.clone()..upper.clone())
}
180
181fn score_snippet(
182    identifier: &Identifier,
183    references: &[Reference],
184    declaration: Declaration,
185    is_same_file: bool,
186    declaration_line_distance: u32,
187    declaration_line_distance_rank: usize,
188    same_file_declaration_count: usize,
189    declaration_count: usize,
190    containing_range_identifier_occurrences: &IdentifierOccurrences,
191    adjacent_identifier_occurrences: &IdentifierOccurrences,
192    cursor: Point,
193    current_buffer: &BufferSnapshot,
194) -> Option<ScoredSnippet> {
195    let is_referenced_nearby = references
196        .iter()
197        .any(|r| r.region == ReferenceRegion::Nearby);
198    let is_referenced_in_breadcrumb = references
199        .iter()
200        .any(|r| r.region == ReferenceRegion::Breadcrumb);
201    let reference_count = references.len();
202    let reference_line_distance = references
203        .iter()
204        .map(|r| {
205            let reference_line = r.range.start.to_point(current_buffer).row as i32;
206            (cursor.row as i32 - reference_line).unsigned_abs()
207        })
208        .min()
209        .unwrap();
210
211    let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
212    let item_signature_occurrences =
213        IdentifierOccurrences::within_string(&declaration.signature_text().0);
214    let containing_range_vs_item_jaccard = jaccard_similarity(
215        containing_range_identifier_occurrences,
216        &item_source_occurrences,
217    );
218    let containing_range_vs_signature_jaccard = jaccard_similarity(
219        containing_range_identifier_occurrences,
220        &item_signature_occurrences,
221    );
222    let adjacent_vs_item_jaccard =
223        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
224    let adjacent_vs_signature_jaccard =
225        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
226
227    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
228        containing_range_identifier_occurrences,
229        &item_source_occurrences,
230    );
231    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
232        containing_range_identifier_occurrences,
233        &item_signature_occurrences,
234    );
235    let adjacent_vs_item_weighted_overlap =
236        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
237    let adjacent_vs_signature_weighted_overlap =
238        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
239
240    // TODO: Consider adding declaration_file_count
241    let score_components = ScoreComponents {
242        is_same_file,
243        is_referenced_nearby,
244        is_referenced_in_breadcrumb,
245        reference_line_distance,
246        declaration_line_distance,
247        declaration_line_distance_rank,
248        reference_count,
249        same_file_declaration_count,
250        declaration_count,
251        containing_range_vs_item_jaccard,
252        containing_range_vs_signature_jaccard,
253        adjacent_vs_item_jaccard,
254        adjacent_vs_signature_jaccard,
255        containing_range_vs_item_weighted_overlap,
256        containing_range_vs_signature_weighted_overlap,
257        adjacent_vs_item_weighted_overlap,
258        adjacent_vs_signature_weighted_overlap,
259    };
260
261    Some(ScoredSnippet {
262        identifier: identifier.clone(),
263        declaration: declaration,
264        scores: Scores::score(&score_components),
265        score_components,
266    })
267}
268
269#[derive(Clone, Debug, Serialize)]
270pub struct Scores {
271    pub signature: f32,
272    pub declaration: f32,
273}
274
275impl Scores {
276    fn score(components: &ScoreComponents) -> Scores {
277        // TODO: handle truncation
278
279        // Score related to how likely this is the correct declaration, range 0 to 1
280        let accuracy_score = if components.is_same_file {
281            // TODO: use declaration_line_distance_rank
282            1.0 / components.same_file_declaration_count as f32
283        } else {
284            1.0 / components.declaration_count as f32
285        };
286
287        // Score related to the distance between the reference and cursor, range 0 to 1
288        let distance_score = if components.is_referenced_nearby {
289            1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
290        } else {
291            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
292            0.5
293        };
294
295        // For now instead of linear combination, the scores are just multiplied together.
296        let combined_score = 10.0 * accuracy_score * distance_score;
297
298        Scores {
299            signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
300            // declaration score gets boosted both by being multiplied by 2 and by there being more
301            // weighted overlap.
302            declaration: 2.0
303                * combined_score
304                * components.containing_range_vs_item_weighted_overlap,
305        }
306    }
307}