declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::ScoreComponents;
  2use itertools::Itertools as _;
  3use language::BufferSnapshot;
  4use ordered_float::OrderedFloat;
  5use serde::Serialize;
  6use std::{collections::HashMap, ops::Range};
  7use strum::EnumIter;
  8use text::{Point, ToPoint};
  9
 10use crate::{
 11    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
 12    reference::{Reference, ReferenceRegion},
 13    syntax_index::SyntaxIndexState,
 14    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 15};
 16
 17const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 18
 19// TODO:
 20//
 21// * Consider adding declaration_file_count
 22
 23#[derive(Clone, Debug)]
 24pub struct ScoredSnippet {
 25    pub identifier: Identifier,
 26    pub declaration: Declaration,
 27    pub score_components: ScoreComponents,
 28    pub scores: Scores,
 29}
 30
 31// TODO: Consider having "Concise" style corresponding to `concise_text`
 32#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 33pub enum SnippetStyle {
 34    Signature,
 35    Declaration,
 36}
 37
 38impl ScoredSnippet {
 39    /// Returns the score for this snippet with the specified style.
 40    pub fn score(&self, style: SnippetStyle) -> f32 {
 41        match style {
 42            SnippetStyle::Signature => self.scores.signature,
 43            SnippetStyle::Declaration => self.scores.declaration,
 44        }
 45    }
 46
 47    pub fn size(&self, style: SnippetStyle) -> usize {
 48        // TODO: how to handle truncation?
 49        match &self.declaration {
 50            Declaration::File { declaration, .. } => match style {
 51                SnippetStyle::Signature => declaration.signature_range_in_text.len(),
 52                SnippetStyle::Declaration => declaration.text.len(),
 53            },
 54            Declaration::Buffer { declaration, .. } => match style {
 55                SnippetStyle::Signature => declaration.signature_range.len(),
 56                SnippetStyle::Declaration => declaration.item_range.len(),
 57            },
 58        }
 59    }
 60
 61    pub fn score_density(&self, style: SnippetStyle) -> f32 {
 62        self.score(style) / (self.size(style)) as f32
 63    }
 64}
 65
 66pub fn scored_snippets(
 67    index: &SyntaxIndexState,
 68    excerpt: &EditPredictionExcerpt,
 69    excerpt_text: &EditPredictionExcerptText,
 70    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
 71    cursor_offset: usize,
 72    current_buffer: &BufferSnapshot,
 73) -> Vec<ScoredSnippet> {
 74    let containing_range_identifier_occurrences =
 75        IdentifierOccurrences::within_string(&excerpt_text.body);
 76    let cursor_point = cursor_offset.to_point(&current_buffer);
 77
 78    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
 79    let end_point = Point::new(cursor_point.row + 1, 0);
 80    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
 81        &current_buffer
 82            .text_for_range(start_point..end_point)
 83            .collect::<String>(),
 84    );
 85
 86    let mut snippets = identifier_to_references
 87        .into_iter()
 88        .flat_map(|(identifier, references)| {
 89            let declarations =
 90                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
 91            let declaration_count = declarations.len();
 92
 93            declarations
 94                .into_iter()
 95                .filter_map(|(declaration_id, declaration)| match declaration {
 96                    Declaration::Buffer {
 97                        buffer_id,
 98                        declaration: buffer_declaration,
 99                        ..
100                    } => {
101                        let is_same_file = buffer_id == &current_buffer.remote_id();
102
103                        if is_same_file {
104                            let overlaps_excerpt =
105                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
106                                    .is_some();
107                            if overlaps_excerpt
108                                || excerpt
109                                    .parent_declarations
110                                    .iter()
111                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
112                            {
113                                None
114                            } else {
115                                let declaration_line = buffer_declaration
116                                    .item_range
117                                    .start
118                                    .to_point(current_buffer)
119                                    .row;
120                                Some((
121                                    true,
122                                    (cursor_point.row as i32 - declaration_line as i32)
123                                        .unsigned_abs(),
124                                    declaration,
125                                ))
126                            }
127                        } else {
128                            Some((false, u32::MAX, declaration))
129                        }
130                    }
131                    Declaration::File { .. } => {
132                        // We can assume that a file declaration is in a different file,
133                        // because the current one must be open
134                        Some((false, u32::MAX, declaration))
135                    }
136                })
137                .sorted_by_key(|&(_, distance, _)| distance)
138                .enumerate()
139                .map(
140                    |(
141                        declaration_line_distance_rank,
142                        (is_same_file, declaration_line_distance, declaration),
143                    )| {
144                        let same_file_declaration_count = index.file_declaration_count(declaration);
145
146                        score_snippet(
147                            &identifier,
148                            &references,
149                            declaration.clone(),
150                            is_same_file,
151                            declaration_line_distance,
152                            declaration_line_distance_rank,
153                            same_file_declaration_count,
154                            declaration_count,
155                            &containing_range_identifier_occurrences,
156                            &adjacent_identifier_occurrences,
157                            cursor_point,
158                            current_buffer,
159                        )
160                    },
161                )
162                .collect::<Vec<_>>()
163        })
164        .flatten()
165        .collect::<Vec<_>>();
166
167    snippets.sort_unstable_by_key(|snippet| {
168        OrderedFloat(
169            snippet
170                .score_density(SnippetStyle::Declaration)
171                .max(snippet.score_density(SnippetStyle::Signature)),
172        )
173    });
174
175    snippets
176}
177
/// Returns the overlap of two half-open ranges.
///
/// Returns `None` when the overlap is empty, which includes ranges that merely touch at an
/// endpoint.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lo = Ord::max(a.start.clone(), b.start.clone());
    let hi = Ord::min(a.end.clone(), b.end.clone());
    (lo < hi).then(|| lo..hi)
}
187
188fn score_snippet(
189    identifier: &Identifier,
190    references: &[Reference],
191    declaration: Declaration,
192    is_same_file: bool,
193    declaration_line_distance: u32,
194    declaration_line_distance_rank: usize,
195    same_file_declaration_count: usize,
196    declaration_count: usize,
197    containing_range_identifier_occurrences: &IdentifierOccurrences,
198    adjacent_identifier_occurrences: &IdentifierOccurrences,
199    cursor: Point,
200    current_buffer: &BufferSnapshot,
201) -> Option<ScoredSnippet> {
202    let is_referenced_nearby = references
203        .iter()
204        .any(|r| r.region == ReferenceRegion::Nearby);
205    let is_referenced_in_breadcrumb = references
206        .iter()
207        .any(|r| r.region == ReferenceRegion::Breadcrumb);
208    let reference_count = references.len();
209    let reference_line_distance = references
210        .iter()
211        .map(|r| {
212            let reference_line = r.range.start.to_point(current_buffer).row as i32;
213            (cursor.row as i32 - reference_line).unsigned_abs()
214        })
215        .min()
216        .unwrap();
217
218    let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
219    let item_signature_occurrences =
220        IdentifierOccurrences::within_string(&declaration.signature_text().0);
221    let containing_range_vs_item_jaccard = jaccard_similarity(
222        containing_range_identifier_occurrences,
223        &item_source_occurrences,
224    );
225    let containing_range_vs_signature_jaccard = jaccard_similarity(
226        containing_range_identifier_occurrences,
227        &item_signature_occurrences,
228    );
229    let adjacent_vs_item_jaccard =
230        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
231    let adjacent_vs_signature_jaccard =
232        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
233
234    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
235        containing_range_identifier_occurrences,
236        &item_source_occurrences,
237    );
238    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
239        containing_range_identifier_occurrences,
240        &item_signature_occurrences,
241    );
242    let adjacent_vs_item_weighted_overlap =
243        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
244    let adjacent_vs_signature_weighted_overlap =
245        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
246
247    let score_components = ScoreComponents {
248        is_same_file,
249        is_referenced_nearby,
250        is_referenced_in_breadcrumb,
251        reference_line_distance,
252        declaration_line_distance,
253        declaration_line_distance_rank,
254        reference_count,
255        same_file_declaration_count,
256        declaration_count,
257        containing_range_vs_item_jaccard,
258        containing_range_vs_signature_jaccard,
259        adjacent_vs_item_jaccard,
260        adjacent_vs_signature_jaccard,
261        containing_range_vs_item_weighted_overlap,
262        containing_range_vs_signature_weighted_overlap,
263        adjacent_vs_item_weighted_overlap,
264        adjacent_vs_signature_weighted_overlap,
265    };
266
267    Some(ScoredSnippet {
268        identifier: identifier.clone(),
269        declaration: declaration,
270        scores: Scores::score(&score_components),
271        score_components,
272    })
273}
274
275#[derive(Clone, Debug, Serialize)]
276pub struct Scores {
277    pub signature: f32,
278    pub declaration: f32,
279}
280
281impl Scores {
282    fn score(components: &ScoreComponents) -> Scores {
283        // Score related to how likely this is the correct declaration, range 0 to 1
284        let accuracy_score = if components.is_same_file {
285            // TODO: use declaration_line_distance_rank
286            1.0 / components.same_file_declaration_count as f32
287        } else {
288            1.0 / components.declaration_count as f32
289        };
290
291        // Score related to the distance between the reference and cursor, range 0 to 1
292        let distance_score = if components.is_referenced_nearby {
293            1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
294        } else {
295            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
296            0.5
297        };
298
299        // For now instead of linear combination, the scores are just multiplied together.
300        let combined_score = 10.0 * accuracy_score * distance_score;
301
302        Scores {
303            signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
304            // declaration score gets boosted both by being multiplied by 2 and by there being more
305            // weighted overlap.
306            declaration: 2.0
307                * combined_score
308                * components.containing_range_vs_item_weighted_overlap,
309        }
310    }
311}