scored_declaration.rs

  1use collections::HashSet;
  2use gpui::{App, Entity};
  3use itertools::Itertools as _;
  4use language::BufferSnapshot;
  5use project::ProjectEntryId;
  6use serde::Serialize;
  7use std::{collections::HashMap, ops::Range};
  8use strum::EnumIter;
  9use text::{OffsetRangeExt, Point, ToPoint};
 10
 11use crate::{
 12    Declaration, EditPredictionExcerpt, EditPredictionExcerptText, TreeSitterIndex,
 13    outline::Identifier,
 14    reference::{Reference, ReferenceRegion},
 15    text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
 16};
 17
 18// TODO:
 19//
 20// * Consider adding declaration_file_count (n)
 21
 22#[derive(Clone, Debug)]
 23pub struct ScoredSnippet {
 24    #[allow(dead_code)]
 25    pub identifier: Identifier,
 26    pub declaration: Declaration,
 27    pub score_components: ScoreInputs,
 28    pub scores: Scores,
 29}
 30
 31// TODO: Consider having "Concise" style corresponding to `concise_text`
 32#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
 33pub enum SnippetStyle {
 34    Signature,
 35    Declaration,
 36}
 37
 38impl ScoredSnippet {
 39    /// Returns the score for this snippet with the specified style.
 40    pub fn score(&self, style: SnippetStyle) -> f32 {
 41        match style {
 42            SnippetStyle::Signature => self.scores.signature,
 43            SnippetStyle::Declaration => self.scores.declaration,
 44        }
 45    }
 46
 47    pub fn size(&self, style: SnippetStyle) -> usize {
 48        todo!()
 49    }
 50
 51    pub fn score_density(&self, style: SnippetStyle) -> f32 {
 52        self.score(style) / (self.size(style)) as f32
 53    }
 54}
 55
 56fn scored_snippets(
 57    index: Entity<TreeSitterIndex>,
 58    excerpt: &EditPredictionExcerpt,
 59    excerpt_text: &EditPredictionExcerptText,
 60    identifier_to_references: HashMap<Identifier, Vec<Reference>>,
 61    cursor_offset: usize,
 62    current_buffer: &BufferSnapshot,
 63    cx: &App,
 64) -> Vec<ScoredSnippet> {
 65    let containing_range_identifier_occurrences =
 66        IdentifierOccurrences::within_string(&excerpt_text.body);
 67    let cursor_point = cursor_offset.to_point(&current_buffer);
 68
 69    // todo! ask michael why we needed this
 70    // if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
 71    // } else {
 72    // };
 73    let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
 74    let end_point = Point::new(cursor_point.row + 1, 0);
 75    let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
 76        &current_buffer
 77            .text_for_range(start_point..end_point)
 78            .collect::<String>(),
 79    );
 80
 81    identifier_to_references
 82        .into_iter()
 83        .flat_map(|(identifier, references)| {
 84            let declarations = index
 85                .read(cx)
 86                // todo! pick a limit
 87                .declarations_for_identifier::<16>(&identifier, cx);
 88            let declaration_count = declarations.len();
 89
 90            declarations
 91                .iter()
 92                .filter_map(|declaration| match declaration {
 93                    Declaration::Buffer {
 94                        declaration,
 95                        buffer,
 96                    } => {
 97                        let is_same_file = buffer
 98                            .read_with(cx, |buffer, _| buffer.remote_id())
 99                            .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
100
101                        if is_same_file {
102                            range_intersection(
103                                &declaration.item_range.to_offset(&current_buffer),
104                                &excerpt.range,
105                            )
106                            .is_none()
107                            .then(|| {
108                                let declaration_line =
109                                    declaration.item_range.start.to_point(current_buffer).row;
110                                (
111                                    true,
112                                    (cursor_point.row as i32 - declaration_line as i32).abs()
113                                        as u32,
114                                    declaration,
115                                )
116                            })
117                        } else {
118                            Some((false, 0, declaration))
119                        }
120                    }
121                    Declaration::File { .. } => {
122                        // We can assume that a file declaration is in a different file,
123                        // because the current onemust be open
124                        Some((false, 0, declaration))
125                    }
126                })
127                .sorted_by_key(|&(_, distance, _)| distance)
128                .enumerate()
129                .map(
130                    |(
131                        declaration_line_distance_rank,
132                        (is_same_file, declaration_line_distance, declaration),
133                    )| {
134                        let same_file_declaration_count =
135                            index.read(cx).file_declaration_count(declaration);
136
137                        score_snippet(
138                            &identifier,
139                            &references,
140                            declaration.clone(),
141                            is_same_file,
142                            declaration_line_distance,
143                            declaration_line_distance_rank,
144                            same_file_declaration_count,
145                            declaration_count,
146                            &containing_range_identifier_occurrences,
147                            &adjacent_identifier_occurrences,
148                            cursor_point,
149                            current_buffer,
150                            cx,
151                        )
152                    },
153                )
154                .collect::<Vec<_>>()
155        })
156        .flatten()
157        .collect::<Vec<_>>()
158}
159
// todo! replace with existing util?
/// Returns the overlap of `a` and `b`, or `None` when they share no non-empty span.
///
/// Ranges that merely touch (one ends where the other starts) do not intersect.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(&a.start, &b.start);
    let upper = std::cmp::min(&a.end, &b.end);
    (lower < upper).then(|| lower.clone()..upper.clone())
}
170
171fn score_snippet(
172    identifier: &Identifier,
173    references: &[Reference],
174    declaration: Declaration,
175    is_same_file: bool,
176    declaration_line_distance: u32,
177    declaration_line_distance_rank: usize,
178    same_file_declaration_count: usize,
179    declaration_count: usize,
180    containing_range_identifier_occurrences: &IdentifierOccurrences,
181    adjacent_identifier_occurrences: &IdentifierOccurrences,
182    cursor: Point,
183    current_buffer: &BufferSnapshot,
184    cx: &App,
185) -> Option<ScoredSnippet> {
186    let is_referenced_nearby = references
187        .iter()
188        .any(|r| r.region == ReferenceRegion::Nearby);
189    let is_referenced_in_breadcrumb = references
190        .iter()
191        .any(|r| r.region == ReferenceRegion::Breadcrumb);
192    let reference_count = references.len();
193    let reference_line_distance = references
194        .iter()
195        .map(|r| {
196            let reference_line = r.range.start.to_point(current_buffer).row as i32;
197            (cursor.row as i32 - reference_line).abs() as u32
198        })
199        .min()
200        .unwrap();
201
202    let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text(cx));
203    let item_signature_occurrences =
204        IdentifierOccurrences::within_string(&declaration.signature_text(cx));
205    let containing_range_vs_item_jaccard = jaccard_similarity(
206        containing_range_identifier_occurrences,
207        &item_source_occurrences,
208    );
209    let containing_range_vs_signature_jaccard = jaccard_similarity(
210        containing_range_identifier_occurrences,
211        &item_signature_occurrences,
212    );
213    let adjacent_vs_item_jaccard =
214        jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
215    let adjacent_vs_signature_jaccard =
216        jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
217
218    let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
219        containing_range_identifier_occurrences,
220        &item_source_occurrences,
221    );
222    let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
223        containing_range_identifier_occurrences,
224        &item_signature_occurrences,
225    );
226    let adjacent_vs_item_weighted_overlap =
227        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
228    let adjacent_vs_signature_weighted_overlap =
229        weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
230
231    let score_components = ScoreInputs {
232        is_same_file,
233        is_referenced_nearby,
234        is_referenced_in_breadcrumb,
235        reference_line_distance,
236        declaration_line_distance,
237        declaration_line_distance_rank,
238        reference_count,
239        same_file_declaration_count,
240        declaration_count,
241        containing_range_vs_item_jaccard,
242        containing_range_vs_signature_jaccard,
243        adjacent_vs_item_jaccard,
244        adjacent_vs_signature_jaccard,
245        containing_range_vs_item_weighted_overlap,
246        containing_range_vs_signature_weighted_overlap,
247        adjacent_vs_item_weighted_overlap,
248        adjacent_vs_signature_weighted_overlap,
249    };
250
251    Some(ScoredSnippet {
252        identifier: identifier.clone(),
253        declaration: declaration,
254        scores: score_components.score(),
255        score_components,
256    })
257}
258
259#[derive(Clone, Debug, Serialize)]
260pub struct ScoreInputs {
261    pub is_same_file: bool,
262    pub is_referenced_nearby: bool,
263    pub is_referenced_in_breadcrumb: bool,
264    pub reference_count: usize,
265    pub same_file_declaration_count: usize,
266    pub declaration_count: usize,
267    pub reference_line_distance: u32,
268    pub declaration_line_distance: u32,
269    pub declaration_line_distance_rank: usize,
270    pub containing_range_vs_item_jaccard: f32,
271    pub containing_range_vs_signature_jaccard: f32,
272    pub adjacent_vs_item_jaccard: f32,
273    pub adjacent_vs_signature_jaccard: f32,
274    pub containing_range_vs_item_weighted_overlap: f32,
275    pub containing_range_vs_signature_weighted_overlap: f32,
276    pub adjacent_vs_item_weighted_overlap: f32,
277    pub adjacent_vs_signature_weighted_overlap: f32,
278}
279
280#[derive(Clone, Debug, Serialize)]
281pub struct Scores {
282    pub signature: f32,
283    pub declaration: f32,
284}
285
286impl ScoreInputs {
287    fn score(&self) -> Scores {
288        // Score related to how likely this is the correct declaration, range 0 to 1
289        let accuracy_score = if self.is_same_file {
290            // TODO: use declaration_line_distance_rank
291            (0.5 / self.same_file_declaration_count as f32)
292        } else {
293            1.0 / self.declaration_count as f32
294        };
295
296        // Score related to the distance between the reference and cursor, range 0 to 1
297        let distance_score = if self.is_referenced_nearby {
298            1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
299        } else {
300            // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
301            0.5
302        };
303
304        // For now instead of linear combination, the scores are just multiplied together.
305        let combined_score = 10.0 * accuracy_score * distance_score;
306
307        Scores {
308            signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
309            // declaration score gets boosted both by being multipled by 2 and by there being more
310            // weighted overlap.
311            declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
312        }
313    }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use text::ToOffset;
    use util::path;

    use crate::{
        EditPredictionExcerptOptions, references_in_excerpt, tree_sitter_index::TreeSitterIndex,
    };

    /// Places the cursor inside the first `process_data(...)` call of `c.rs` and expects the
    /// `process_data` declaration to be the only scored snippet.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let c_rs_buffer = project
            .update(cx, |project, cx| {
                let c_rs_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(c_rs_path, cx)
            })
            .await
            .unwrap();

        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let snapshot = c_rs_buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let excerpt_options = EditPredictionExcerptOptions {
            max_bytes: 40,
            min_bytes: 10,
            target_before_cursor_over_total_bytes: 0.5,
            include_parent_signatures: false,
        };
        let excerpt =
            EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &excerpt_options)
                .unwrap();
        let excerpt_text = excerpt.text(&snapshot);
        let references = references_in_excerpt(&excerpt, &excerpt_text, &snapshot);
        let cursor_offset = cursor_point.to_offset(&snapshot);

        let snippets = cx.update(|cx| {
            scored_snippets(
                index,
                &excerpt,
                &excerpt_text,
                references,
                cursor_offset,
                &snapshot,
                cx,
            )
        });

        assert_eq!(snippets.len(), 1);
        assert_eq!(snippets[0].identifier.name.as_ref(), "process_data");
        // The buffer handle is held until the assertions have run.
        drop(c_rs_buffer);
    }

    /// Creates a fake project under `/root` containing `a.rs`, `b.rs` and `c.rs`, registers the
    /// Rust language, and builds a `TreeSitterIndex` for the project.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let file_tree = json!({
            "a.rs": indoc! {r#"
                fn main() {
                    let x = 1;
                    let y = 2;
                    let z = add(x, y);
                    println!("Result: {}", z);
                }

                fn add(a: i32, b: i32) -> i32 {
                    a + b
                }
            "#},
            "b.rs": indoc! {"
                pub struct Config {
                    pub name: String,
                    pub value: i32,
                }

                impl Config {
                    pub fn new(name: String, value: i32) -> Self {
                        Config { name, value }
                    }
                }
            "},
            "c.rs": indoc! {r#"
                use std::collections::HashMap;

                fn main() {
                    let args: Vec<String> = std::env::args().collect();
                    let data: Vec<i32> = args[1..]
                        .iter()
                        .filter_map(|s| s.parse().ok())
                        .collect();
                    let result = process_data(data);
                    println!("{:?}", result);
                }

                fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                    let mut counts = HashMap::new();
                    for value in data {
                        *counts.entry(value).or_insert(0) += 1;
                    }
                    counts
                }

                #[cfg(test)]
                mod tests {
                    use super::*;

                    #[test]
                    fn test_process_data() {
                        let data = vec![1, 2, 2, 3];
                        let result = process_data(data);
                        assert_eq!(result.get(&2), Some(&2));
                    }
                }
            "#}
        });

        let fake_fs = FakeFs::new(cx.executor());
        fake_fs.insert_tree(path!("/root"), file_tree).await;

        let project = Project::test(fake_fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let rust = rust_lang();
        let rust_id = rust.id();
        language_registry.add(Arc::new(rust));

        let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
        cx.run_until_parked();

        (project, index, rust_id)
    }

    /// Builds a Rust `Language` matched by the `rs` suffix, with the highlights and outline
    /// queries loaded from the `languages` crate sources.
    fn rust_lang() -> Language {
        let config = LanguageConfig {
            name: "Rust".into(),
            matcher: LanguageMatcher {
                path_suffixes: vec!["rs".to_string()],
                ..Default::default()
            },
            ..Default::default()
        };

        Language::new(config, Some(tree_sitter_rust::LANGUAGE.into()))
            .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
            .unwrap()
            .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
            .unwrap()
    }
}