declaration_scoring.rs

  1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
  2use collections::HashMap;
  3use itertools::Itertools as _;
  4use language::BufferSnapshot;
  5use ordered_float::OrderedFloat;
  6use serde::Serialize;
  7use std::{cmp::Reverse, ops::Range};
  8use strum::EnumIter;
  9use text::{Point, ToPoint};
 10
 11use crate::{
 12    Declaration, EditPredictionExcerpt, Identifier,
 13    reference::{Reference, ReferenceRegion},
 14    syntax_index::SyntaxIndexState,
 15    text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
 16};
 17
 18const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
 19
 20#[derive(Clone, Debug)]
 21pub struct ScoredDeclaration {
 22    pub identifier: Identifier,
 23    pub declaration: Declaration,
 24    pub score_components: DeclarationScoreComponents,
 25    pub scores: DeclarationScores,
 26}
 27
 28#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 29pub enum DeclarationStyle {
 30    Signature,
 31    Declaration,
 32}
 33
 34impl ScoredDeclaration {
 35    /// Returns the score for this declaration with the specified style.
 36    pub fn score(&self, style: DeclarationStyle) -> f32 {
 37        match style {
 38            DeclarationStyle::Signature => self.scores.signature,
 39            DeclarationStyle::Declaration => self.scores.declaration,
 40        }
 41    }
 42
 43    pub fn size(&self, style: DeclarationStyle) -> usize {
 44        match &self.declaration {
 45            Declaration::File { declaration, .. } => match style {
 46                DeclarationStyle::Signature => declaration.signature_range.len(),
 47                DeclarationStyle::Declaration => declaration.text.len(),
 48            },
 49            Declaration::Buffer { declaration, .. } => match style {
 50                DeclarationStyle::Signature => declaration.signature_range.len(),
 51                DeclarationStyle::Declaration => declaration.item_range.len(),
 52            },
 53        }
 54    }
 55
 56    pub fn score_density(&self, style: DeclarationStyle) -> f32 {
 57        self.score(style) / (self.size(style)) as f32
 58    }
 59}
 60
 61pub fn scored_declarations(
 62    index: &SyntaxIndexState,
 63    excerpt: &EditPredictionExcerpt,
 64    excerpt_occurrences: &Occurrences,
 65    adjacent_occurrences: &Occurrences,
 66    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
 67    cursor_offset: usize,
 68    current_buffer: &BufferSnapshot,
 69) -> Vec<ScoredDeclaration> {
 70    let cursor_point = cursor_offset.to_point(&current_buffer);
 71
 72    let mut declarations = identifier_to_references
 73        .into_iter()
 74        .flat_map(|(identifier, references)| {
 75            let declarations =
 76                index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
 77            let declaration_count = declarations.len();
 78
 79            declarations
 80                .into_iter()
 81                .filter_map(|(declaration_id, declaration)| match declaration {
 82                    Declaration::Buffer {
 83                        buffer_id,
 84                        declaration: buffer_declaration,
 85                        ..
 86                    } => {
 87                        let is_same_file = buffer_id == &current_buffer.remote_id();
 88
 89                        if is_same_file {
 90                            let overlaps_excerpt =
 91                                range_intersection(&buffer_declaration.item_range, &excerpt.range)
 92                                    .is_some();
 93                            if overlaps_excerpt
 94                                || excerpt
 95                                    .parent_declarations
 96                                    .iter()
 97                                    .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
 98                            {
 99                                None
100                            } else {
101                                let declaration_line = buffer_declaration
102                                    .item_range
103                                    .start
104                                    .to_point(current_buffer)
105                                    .row;
106                                Some((
107                                    true,
108                                    (cursor_point.row as i32 - declaration_line as i32)
109                                        .unsigned_abs(),
110                                    declaration,
111                                ))
112                            }
113                        } else {
114                            Some((false, u32::MAX, declaration))
115                        }
116                    }
117                    Declaration::File { .. } => {
118                        // We can assume that a file declaration is in a different file,
119                        // because the current one must be open
120                        Some((false, u32::MAX, declaration))
121                    }
122                })
123                .sorted_by_key(|&(_, distance, _)| distance)
124                .enumerate()
125                .map(
126                    |(
127                        declaration_line_distance_rank,
128                        (is_same_file, declaration_line_distance, declaration),
129                    )| {
130                        let same_file_declaration_count = index.file_declaration_count(declaration);
131
132                        score_declaration(
133                            &identifier,
134                            &references,
135                            declaration.clone(),
136                            is_same_file,
137                            declaration_line_distance,
138                            declaration_line_distance_rank,
139                            same_file_declaration_count,
140                            declaration_count,
141                            &excerpt_occurrences,
142                            &adjacent_occurrences,
143                            cursor_point,
144                            current_buffer,
145                        )
146                    },
147                )
148                .collect::<Vec<_>>()
149        })
150        .flatten()
151        .collect::<Vec<_>>();
152
153    declarations.sort_unstable_by_key(|declaration| {
154        let score_density = declaration
155            .score_density(DeclarationStyle::Declaration)
156            .max(declaration.score_density(DeclarationStyle::Signature));
157        Reverse(OrderedFloat(score_density))
158    });
159
160    declarations
161}
162
/// Returns the overlap of two half-open ranges, or `None` when they share no elements.
///
/// Ranges that merely touch (e.g. `0..5` and `5..10`) and empty or inverted results yield `None`.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lo = std::cmp::max(&a.start, &b.start).clone();
    let hi = std::cmp::min(&a.end, &b.end).clone();
    (lo < hi).then(|| lo..hi)
}
172
173fn score_declaration(
174    identifier: &Identifier,
175    references: &[Reference],
176    declaration: Declaration,
177    is_same_file: bool,
178    declaration_line_distance: u32,
179    declaration_line_distance_rank: usize,
180    same_file_declaration_count: usize,
181    declaration_count: usize,
182    excerpt_occurrences: &Occurrences,
183    adjacent_occurrences: &Occurrences,
184    cursor: Point,
185    current_buffer: &BufferSnapshot,
186) -> Option<ScoredDeclaration> {
187    let is_referenced_nearby = references
188        .iter()
189        .any(|r| r.region == ReferenceRegion::Nearby);
190    let is_referenced_in_breadcrumb = references
191        .iter()
192        .any(|r| r.region == ReferenceRegion::Breadcrumb);
193    let reference_count = references.len();
194    let reference_line_distance = references
195        .iter()
196        .map(|r| {
197            let reference_line = r.range.start.to_point(current_buffer).row as i32;
198            (cursor.row as i32 - reference_line).unsigned_abs()
199        })
200        .min()
201        .unwrap();
202
203    let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
204    let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
205    let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
206    let excerpt_vs_signature_jaccard =
207        jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
208    let adjacent_vs_item_jaccard =
209        jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
210    let adjacent_vs_signature_jaccard =
211        jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
212
213    let excerpt_vs_item_weighted_overlap =
214        weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
215    let excerpt_vs_signature_weighted_overlap =
216        weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
217    let adjacent_vs_item_weighted_overlap =
218        weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
219    let adjacent_vs_signature_weighted_overlap =
220        weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
221
222    // TODO: Consider adding declaration_file_count
223    let score_components = DeclarationScoreComponents {
224        is_same_file,
225        is_referenced_nearby,
226        is_referenced_in_breadcrumb,
227        reference_line_distance,
228        declaration_line_distance,
229        declaration_line_distance_rank,
230        reference_count,
231        same_file_declaration_count,
232        declaration_count,
233        excerpt_vs_item_jaccard,
234        excerpt_vs_signature_jaccard,
235        adjacent_vs_item_jaccard,
236        adjacent_vs_signature_jaccard,
237        excerpt_vs_item_weighted_overlap,
238        excerpt_vs_signature_weighted_overlap,
239        adjacent_vs_item_weighted_overlap,
240        adjacent_vs_signature_weighted_overlap,
241    };
242
243    Some(ScoredDeclaration {
244        identifier: identifier.clone(),
245        declaration: declaration,
246        scores: DeclarationScores::score(&score_components),
247        score_components,
248    })
249}
250
251#[derive(Clone, Debug, Serialize)]
252pub struct DeclarationScores {
253    pub signature: f32,
254    pub declaration: f32,
255    pub retrieval: f32,
256}
257
258impl DeclarationScores {
259    fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
260        // TODO: handle truncation
261
262        // Score related to how likely this is the correct declaration, range 0 to 1
263        let retrieval = if components.is_same_file {
264            // TODO: use declaration_line_distance_rank
265            1.0 / components.same_file_declaration_count as f32
266        } else {
267            1.0 / components.declaration_count as f32
268        };
269
270        // Score related to the distance between the reference and cursor, range 0 to 1
271        let distance_score = if components.is_referenced_nearby {
272            1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
273        } else {
274            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
275            0.5
276        };
277
278        // For now instead of linear combination, the scores are just multiplied together.
279        let combined_score = 10.0 * retrieval * distance_score;
280
281        DeclarationScores {
282            signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
283            // declaration score gets boosted both by being multiplied by 2 and by there being more
284            // weighted overlap.
285            declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
286            retrieval,
287        }
288    }
289}