declaration_scoring.rs

  1use gpui::{App, Entity};
  2use itertools::Itertools as _;
  3use language::BufferSnapshot;
  4use serde::Serialize;
  5use std::{collections::HashMap, ops::Range};
  6use strum::EnumIter;
  7use text::{OffsetRangeExt, Point, ToPoint};
  8
  9use crate::{
 10    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier, SyntaxIndex,
 11    reference::{Reference, ReferenceRegion},
 12    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 13};
 14
 15// TODO:
 16//
 17// * Consider adding declaration_file_count (n)
 18
/// A candidate declaration for an identifier referenced near the cursor, together with the
/// raw inputs used to score it and the resulting per-style scores.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The referenced identifier this declaration was found for.
    // NOTE(review): `dead_code` is allowed without explanation; the test module reads this
    // field, so presumably it is kept for tests/debugging — confirm before removing the allow.
    #[allow(dead_code)]
    pub identifier: Identifier,
    /// The declaration being scored.
    pub declaration: Declaration,
    /// Raw features computed for this declaration (built in `score_snippet`).
    pub score_components: ScoreInputs,
    /// Scores derived from `score_components` via `ScoreInputs::score`.
    pub scores: Scores,
}
 27
// TODO: Consider having "Concise" style corresponding to `concise_text`
/// Which form of a declaration a snippet represents; selects the matching field of [`Scores`]
/// in `ScoredSnippet::score`.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// The declaration's signature; scored with `Scores::signature`.
    Signature,
    /// The whole declaration; scored with `Scores::declaration`.
    Declaration,
}
 34
 35impl ScoredSnippet {
 36    /// Returns the score for this snippet with the specified style.
 37    pub fn score(&self, style: SnippetStyle) -> f32 {
 38        match style {
 39            SnippetStyle::Signature => self.scores.signature,
 40            SnippetStyle::Declaration => self.scores.declaration,
 41        }
 42    }
 43
 44    pub fn size(&self, style: SnippetStyle) -> usize {
 45        todo!()
 46    }
 47
 48    pub fn score_density(&self, style: SnippetStyle) -> f32 {
 49        self.score(style) / (self.size(style)) as f32
 50    }
 51}
 52
 53fn scored_snippets(
 54    index: Entity<SyntaxIndex>,
 55    excerpt: &EditPredictionExcerpt,
 56    excerpt_text: &EditPredictionExcerptText,
 57    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
 58    cursor_offset: usize,
 59    current_buffer: &BufferSnapshot,
 60    cx: &App,
 61) -> Vec<ScoredSnippet> {
 62    let containing_range_identifier_occurrences =
 63        IdentifierOccurrences::within_string(&excerpt_text.body);
 64    let cursor_point = cursor_offset.to_point(&current_buffer);
 65
 66    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
 67    let end_point = Point::new(cursor_point.row + 1, 0);
 68    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
 69        &current_buffer
 70            .text_for_range(start_point..end_point)
 71            .collect::<String>(),
 72    );
 73
 74    identifier_to_references
 75        .into_iter()
 76        .flat_map(|(identifier, references)| {
 77            let declarations = index
 78                .read(cx)
 79                // todo! pick a limit
 80                .declarations_for_identifier::<16>(&identifier, cx);
 81            let declaration_count = declarations.len();
 82
 83            declarations
 84                .iter()
 85                .filter_map(|declaration| match declaration {
 86                    Declaration::Buffer {
 87                        declaration: buffer_declaration,
 88                        buffer,
 89                    } => {
 90                        let is_same_file = buffer
 91                            .read_with(cx, |buffer, _| buffer.remote_id())
 92                            .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
 93
 94                        if is_same_file {
 95                            range_intersection(
 96                                &buffer_declaration.item_range.to_offset(&current_buffer),
 97                                &excerpt.range,
 98                            )
 99                            .is_none()
100                            .then(|| {
101                                let declaration_line = buffer_declaration
102                                    .item_range
103                                    .start
104                                    .to_point(current_buffer)
105                                    .row;
106                                (
107                                    true,
108                                    (cursor_point.row as i32 - declaration_line as i32).abs()
109                                        as u32,
110                                    declaration,
111                                )
112                            })
113                        } else {
114                            Some((false, 0, declaration))
115                        }
116                    }
117                    Declaration::File { .. } => {
118                        // We can assume that a file declaration is in a different file,
119                        // because the current one must be open
120                        Some((false, 0, declaration))
121                    }
122                })
123                .sorted_by_key(|&(_, distance, _)| distance)
124                .enumerate()
125                .map(
126                    |(
127                        declaration_line_distance_rank,
128                        (is_same_file, declaration_line_distance, declaration),
129                    )| {
130                        let same_file_declaration_count =
131                            index.read(cx).file_declaration_count(declaration);
132
133                        score_snippet(
134                            &identifier,
135                            &references,
136                            declaration.clone(),
137                            is_same_file,
138                            declaration_line_distance,
139                            declaration_line_distance_rank,
140                            same_file_declaration_count,
141                            declaration_count,
142                            &containing_range_identifier_occurrences,
143                            &adjacent_identifier_occurrences,
144                            cursor_point,
145                            current_buffer,
146                            cx,
147                        )
148                    },
149                )
150                .collect::<Vec<_>>()
151        })
152        .flatten()
153        .collect::<Vec<_>>()
154}
155
// todo! replace with existing util?
/// Returns the overlap of the half-open ranges `a` and `b`, or `None` when it is empty.
///
/// Ranges that merely touch (e.g. `0..3` and `3..5`) and empty input ranges yield `None`.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let overlap_start = Ord::max(a.start.clone(), b.start.clone());
    let overlap_end = Ord::min(a.end.clone(), b.end.clone());
    (overlap_start < overlap_end).then(|| overlap_start..overlap_end)
}
166
167fn score_snippet(
168    identifier: &Identifier,
169    references: &[Reference],
170    declaration: Declaration,
171    is_same_file: bool,
172    declaration_line_distance: u32,
173    declaration_line_distance_rank: usize,
174    same_file_declaration_count: usize,
175    declaration_count: usize,
176    containing_range_identifier_occurrences: &IdentifierOccurrences,
177    adjacent_identifier_occurrences: &IdentifierOccurrences,
178    cursor: Point,
179    current_buffer: &BufferSnapshot,
180    cx: &App,
181) -> Option<ScoredSnippet> {
182    let is_referenced_nearby = references
183        .iter()
184        .any(|r| r.region == ReferenceRegion::Nearby);
185    let is_referenced_in_breadcrumb = references
186        .iter()
187        .any(|r| r.region == ReferenceRegion::Breadcrumb);
188    let reference_count = references.len();
189    let reference_line_distance = references
190        .iter()
191        .map(|r| {
192            let reference_line = r.range.start.to_point(current_buffer).row as i32;
193            (cursor.row as i32 - reference_line).abs() as u32
194        })
195        .min()
196        .unwrap();
197
198    let item_source_occurrences =
199        IdentifierOccurrences::within_string(&declaration.item_text(cx).0);
200    let item_signature_occurrences =
201        IdentifierOccurrences::within_string(&declaration.signature_text(cx).0);
202    let containing_range_vs_item_jaccard = jaccard_similarity(
203        containing_range_identifier_occurrences,
204        &item_source_occurrences,
205    );
206    let containing_range_vs_signature_jaccard = jaccard_similarity(
207        containing_range_identifier_occurrences,
208        &item_signature_occurrences,
209    );
210    let adjacent_vs_item_jaccard =
211        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
212    let adjacent_vs_signature_jaccard =
213        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
214
215    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
216        containing_range_identifier_occurrences,
217        &item_source_occurrences,
218    );
219    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
220        containing_range_identifier_occurrences,
221        &item_signature_occurrences,
222    );
223    let adjacent_vs_item_weighted_overlap =
224        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
225    let adjacent_vs_signature_weighted_overlap =
226        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
227
228    let score_components = ScoreInputs {
229        is_same_file,
230        is_referenced_nearby,
231        is_referenced_in_breadcrumb,
232        reference_line_distance,
233        declaration_line_distance,
234        declaration_line_distance_rank,
235        reference_count,
236        same_file_declaration_count,
237        declaration_count,
238        containing_range_vs_item_jaccard,
239        containing_range_vs_signature_jaccard,
240        adjacent_vs_item_jaccard,
241        adjacent_vs_signature_jaccard,
242        containing_range_vs_item_weighted_overlap,
243        containing_range_vs_signature_weighted_overlap,
244        adjacent_vs_item_weighted_overlap,
245        adjacent_vs_signature_weighted_overlap,
246    };
247
248    Some(ScoredSnippet {
249        identifier: identifier.clone(),
250        declaration: declaration,
251        scores: score_components.score(),
252        score_components,
253    })
254}
255
/// Raw features describing one candidate declaration, combined into [`Scores`] by
/// `ScoreInputs::score` and serializable for inspection.
///
/// "Containing range" refers to the excerpt body around the cursor. "Adjacent" refers to the
/// text from two rows above the cursor row through the cursor row, as computed in
/// `scored_snippets`.
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
    /// Whether the declaration is in the current buffer (always `false` for file declarations).
    pub is_same_file: bool,
    /// Whether any reference has region `ReferenceRegion::Nearby`.
    pub is_referenced_nearby: bool,
    /// Whether any reference has region `ReferenceRegion::Breadcrumb`.
    pub is_referenced_in_breadcrumb: bool,
    /// Number of references to the identifier that were passed in.
    pub reference_count: usize,
    /// Value of `SyntaxIndex::file_declaration_count` for this declaration.
    pub same_file_declaration_count: usize,
    /// Number of declarations returned for the identifier, counted before any are skipped for
    /// overlapping the excerpt.
    pub declaration_count: usize,
    /// Smallest row distance between the cursor and any reference.
    pub reference_line_distance: u32,
    /// Row distance between the cursor and the declaration's start; 0 for other files.
    pub declaration_line_distance: u32,
    /// Index of this declaration among the identifier's remaining candidates after a stable
    /// sort by `declaration_line_distance`.
    pub declaration_line_distance_rank: usize,
    /// Jaccard similarity: excerpt identifiers vs. declaration item-text identifiers.
    pub containing_range_vs_item_jaccard: f32,
    /// Jaccard similarity: excerpt identifiers vs. declaration signature identifiers.
    pub containing_range_vs_signature_jaccard: f32,
    /// Jaccard similarity: adjacent-row identifiers vs. declaration item-text identifiers.
    pub adjacent_vs_item_jaccard: f32,
    /// Jaccard similarity: adjacent-row identifiers vs. declaration signature identifiers.
    pub adjacent_vs_signature_jaccard: f32,
    /// Weighted overlap coefficient: excerpt identifiers vs. item-text identifiers.
    pub containing_range_vs_item_weighted_overlap: f32,
    /// Weighted overlap coefficient: excerpt identifiers vs. signature identifiers.
    pub containing_range_vs_signature_weighted_overlap: f32,
    /// Weighted overlap coefficient: adjacent-row identifiers vs. item-text identifiers.
    pub adjacent_vs_item_weighted_overlap: f32,
    /// Weighted overlap coefficient: adjacent-row identifiers vs. signature identifiers.
    pub adjacent_vs_signature_weighted_overlap: f32,
}
276
/// Final scores for a snippet, one per [`SnippetStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score used for `SnippetStyle::Signature`.
    pub signature: f32,
    /// Score used for `SnippetStyle::Declaration`.
    pub declaration: f32,
}
282
283impl ScoreInputs {
284    fn score(&self) -> Scores {
285        // Score related to how likely this is the correct declaration, range 0 to 1
286        let accuracy_score = if self.is_same_file {
287            // TODO: use declaration_line_distance_rank
288            (0.5 / self.same_file_declaration_count as f32)
289        } else {
290            1.0 / self.declaration_count as f32
291        };
292
293        // Score related to the distance between the reference and cursor, range 0 to 1
294        let distance_score = if self.is_referenced_nearby {
295            1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
296        } else {
297            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
298            0.5
299        };
300
301        // For now instead of linear combination, the scores are just multiplied together.
302        let combined_score = 10.0 * accuracy_score * distance_score;
303
304        Scores {
305            signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
306            // declaration score gets boosted both by being multipled by 2 and by there being more
307            // weighted overlap.
308            declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
309        }
310    }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use text::ToOffset;
    use util::path;

    use crate::{EditPredictionExcerptOptions, references_in_excerpt};

    /// Scores snippets with the cursor on the first `process_data` call site in `c.rs` and
    /// expects exactly one snippet, for `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Run pending tasks before snapshotting (presumably so the index has seen the
        // opened buffer).
        cx.run_until_parked();

        // first process_data call site
        // (row 8 of c.rs is `let result = process_data(data);`; column 21 is inside the name)
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        // A small 10-40 byte excerpt, targeting half of its bytes before the cursor, without
        // parent signatures.
        let excerpt = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &buffer_snapshot,
            &EditPredictionExcerptOptions {
                max_bytes: 40,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
                include_parent_signatures: false,
            },
        )
        .unwrap();
        let excerpt_text = excerpt.text(&buffer_snapshot);
        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer_snapshot);
        let cursor_offset = cursor_point.to_offset(&buffer_snapshot);

        let snippets = cx.update(|cx| {
            scored_snippets(
                index,
                &excerpt,
                &excerpt_text,
                references,
                cursor_offset,
                &buffer_snapshot,
                cx,
            )
        });

        assert_eq!(snippets.len(), 1);
        assert_eq!(snippets[0].identifier.name.as_ref(), "process_data");
        // Dropping explicitly here keeps the buffer handle alive for the whole test.
        drop(buffer);
    }

    /// Creates a test project rooted at `/root` containing `a.rs`, `b.rs`, and `c.rs`,
    /// registers a Rust language, and builds a [`SyntaxIndex`] for the project.
    ///
    /// Returns the project, the index entity, and the registered Rust language's id.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // Run pending tasks (presumably including the index's initial pass) before returning.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust [`Language`] matching `*.rs`, with the highlights and outline queries
    /// from the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}