1use collections::HashSet;
2use gpui::{App, Entity};
3use itertools::Itertools as _;
4use language::BufferSnapshot;
5use project::ProjectEntryId;
6use serde::Serialize;
7use std::{collections::HashMap, ops::Range};
8use strum::EnumIter;
9use text::{OffsetRangeExt, Point, ToPoint};
10
11use crate::{
12 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, TreeSitterIndex,
13 outline::Identifier,
14 reference::{Reference, ReferenceRegion},
15 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
16};
17
18// TODO:
19//
20// * Consider adding declaration_file_count (n)
21
/// A candidate declaration for a referenced identifier, together with the raw inputs that were
/// measured for it and the scores derived from those inputs.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The identifier whose references led to this declaration being considered.
    // NOTE(review): within this file the field is only read by the test module, which is
    // presumably why `dead_code` is allowed here — confirm before removing the attribute.
    #[allow(dead_code)]
    pub identifier: Identifier,
    /// The declaration that was scored.
    pub declaration: Declaration,
    /// The measurements the scores were computed from.
    pub score_components: ScoreInputs,
    /// One score per [`SnippetStyle`], computed by `ScoreInputs::score`.
    pub scores: Scores,
}
30
// TODO: Consider having "Concise" style corresponding to `concise_text`
/// Selects which of a snippet's scores is used; see `ScoredSnippet::score`.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// Selects `Scores::signature`.
    Signature,
    /// Selects `Scores::declaration`.
    Declaration,
}
37
38impl ScoredSnippet {
39 /// Returns the score for this snippet with the specified style.
40 pub fn score(&self, style: SnippetStyle) -> f32 {
41 match style {
42 SnippetStyle::Signature => self.scores.signature,
43 SnippetStyle::Declaration => self.scores.declaration,
44 }
45 }
46
47 pub fn size(&self, style: SnippetStyle) -> usize {
48 todo!()
49 }
50
51 pub fn score_density(&self, style: SnippetStyle) -> f32 {
52 self.score(style) / (self.size(style)) as f32
53 }
54}
55
56fn scored_snippets(
57 index: Entity<TreeSitterIndex>,
58 excerpt: &EditPredictionExcerpt,
59 excerpt_text: &EditPredictionExcerptText,
60 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
61 cursor_offset: usize,
62 current_buffer: &BufferSnapshot,
63 cx: &App,
64) -> Vec<ScoredSnippet> {
65 let containing_range_identifier_occurrences =
66 IdentifierOccurrences::within_string(&excerpt_text.body);
67 let cursor_point = cursor_offset.to_point(¤t_buffer);
68
69 // todo! ask michael why we needed this
70 // if let Some(cursor_within_excerpt) = cursor_offset.checked_sub(excerpt.range.start) {
71 // } else {
72 // };
73 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
74 let end_point = Point::new(cursor_point.row + 1, 0);
75 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
76 ¤t_buffer
77 .text_for_range(start_point..end_point)
78 .collect::<String>(),
79 );
80
81 identifier_to_references
82 .into_iter()
83 .flat_map(|(identifier, references)| {
84 let declarations = index
85 .read(cx)
86 // todo! pick a limit
87 .declarations_for_identifier::<16>(&identifier, cx);
88 let declaration_count = declarations.len();
89
90 declarations
91 .iter()
92 .filter_map(|declaration| match declaration {
93 Declaration::Buffer {
94 declaration,
95 buffer,
96 } => {
97 let is_same_file = buffer
98 .read_with(cx, |buffer, _| buffer.remote_id())
99 .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
100
101 if is_same_file {
102 range_intersection(
103 &declaration.item_range.to_offset(¤t_buffer),
104 &excerpt.range,
105 )
106 .is_none()
107 .then(|| {
108 let declaration_line =
109 declaration.item_range.start.to_point(current_buffer).row;
110 (
111 true,
112 (cursor_point.row as i32 - declaration_line as i32).abs()
113 as u32,
114 declaration,
115 )
116 })
117 } else {
118 Some((false, 0, declaration))
119 }
120 }
121 Declaration::File { .. } => {
122 // We can assume that a file declaration is in a different file,
123 // because the current onemust be open
124 Some((false, 0, declaration))
125 }
126 })
127 .sorted_by_key(|&(_, distance, _)| distance)
128 .enumerate()
129 .map(
130 |(
131 declaration_line_distance_rank,
132 (is_same_file, declaration_line_distance, declaration),
133 )| {
134 let same_file_declaration_count =
135 index.read(cx).file_declaration_count(declaration);
136
137 score_snippet(
138 &identifier,
139 &references,
140 declaration.clone(),
141 is_same_file,
142 declaration_line_distance,
143 declaration_line_distance_rank,
144 same_file_declaration_count,
145 declaration_count,
146 &containing_range_identifier_occurrences,
147 &adjacent_identifier_occurrences,
148 cursor_point,
149 current_buffer,
150 cx,
151 )
152 },
153 )
154 .collect::<Vec<_>>()
155 })
156 .flatten()
157 .collect::<Vec<_>>()
158}
159
// todo! replace with existing util?
/// Returns the overlap of `a` and `b`, or `None` when the two ranges share no elements.
///
/// Ranges that merely touch (one ends where the other starts) and empty ranges yield `None`.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let start = std::cmp::max(&a.start, &b.start);
    let end = std::cmp::min(&a.end, &b.end);
    (start < end).then(|| start.clone()..end.clone())
}
170
171fn score_snippet(
172 identifier: &Identifier,
173 references: &[Reference],
174 declaration: Declaration,
175 is_same_file: bool,
176 declaration_line_distance: u32,
177 declaration_line_distance_rank: usize,
178 same_file_declaration_count: usize,
179 declaration_count: usize,
180 containing_range_identifier_occurrences: &IdentifierOccurrences,
181 adjacent_identifier_occurrences: &IdentifierOccurrences,
182 cursor: Point,
183 current_buffer: &BufferSnapshot,
184 cx: &App,
185) -> Option<ScoredSnippet> {
186 let is_referenced_nearby = references
187 .iter()
188 .any(|r| r.region == ReferenceRegion::Nearby);
189 let is_referenced_in_breadcrumb = references
190 .iter()
191 .any(|r| r.region == ReferenceRegion::Breadcrumb);
192 let reference_count = references.len();
193 let reference_line_distance = references
194 .iter()
195 .map(|r| {
196 let reference_line = r.range.start.to_point(current_buffer).row as i32;
197 (cursor.row as i32 - reference_line).abs() as u32
198 })
199 .min()
200 .unwrap();
201
202 let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text(cx));
203 let item_signature_occurrences =
204 IdentifierOccurrences::within_string(&declaration.signature_text(cx));
205 let containing_range_vs_item_jaccard = jaccard_similarity(
206 containing_range_identifier_occurrences,
207 &item_source_occurrences,
208 );
209 let containing_range_vs_signature_jaccard = jaccard_similarity(
210 containing_range_identifier_occurrences,
211 &item_signature_occurrences,
212 );
213 let adjacent_vs_item_jaccard =
214 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
215 let adjacent_vs_signature_jaccard =
216 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
217
218 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
219 containing_range_identifier_occurrences,
220 &item_source_occurrences,
221 );
222 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
223 containing_range_identifier_occurrences,
224 &item_signature_occurrences,
225 );
226 let adjacent_vs_item_weighted_overlap =
227 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
228 let adjacent_vs_signature_weighted_overlap =
229 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
230
231 let score_components = ScoreInputs {
232 is_same_file,
233 is_referenced_nearby,
234 is_referenced_in_breadcrumb,
235 reference_line_distance,
236 declaration_line_distance,
237 declaration_line_distance_rank,
238 reference_count,
239 same_file_declaration_count,
240 declaration_count,
241 containing_range_vs_item_jaccard,
242 containing_range_vs_signature_jaccard,
243 adjacent_vs_item_jaccard,
244 adjacent_vs_signature_jaccard,
245 containing_range_vs_item_weighted_overlap,
246 containing_range_vs_signature_weighted_overlap,
247 adjacent_vs_item_weighted_overlap,
248 adjacent_vs_signature_weighted_overlap,
249 };
250
251 Some(ScoredSnippet {
252 identifier: identifier.clone(),
253 declaration: declaration,
254 scores: score_components.score(),
255 score_components,
256 })
257}
258
/// The raw measurements gathered by `score_snippet` for one candidate declaration.
///
/// `ScoreInputs::score` turns these into [`Scores`]; the struct is also serializable.
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
    /// Whether the declaration lives in the buffer that is currently being edited.
    pub is_same_file: bool,
    /// Whether any reference to the identifier has region `ReferenceRegion::Nearby`.
    pub is_referenced_nearby: bool,
    /// Whether any reference to the identifier has region `ReferenceRegion::Breadcrumb`.
    pub is_referenced_in_breadcrumb: bool,
    /// Number of references to the identifier.
    pub reference_count: usize,
    /// Value reported by the index's `file_declaration_count` for this declaration.
    // NOTE(review): presumably the number of declarations in the declaration's file — confirm
    // against `TreeSitterIndex::file_declaration_count`.
    pub same_file_declaration_count: usize,
    /// Number of declarations the index returned for the identifier, before any filtering.
    pub declaration_count: usize,
    /// Smallest row distance between the cursor and a reference to the identifier.
    pub reference_line_distance: u32,
    /// Row distance between the cursor and the start of the declaration; 0 for declarations
    /// outside the current buffer.
    pub declaration_line_distance: u32,
    /// Position of this declaration among the identifier's candidates when ordered by ascending
    /// `declaration_line_distance`.
    pub declaration_line_distance_rank: usize,
    /// `jaccard_similarity` of the excerpt body's identifiers and the declaration's item text.
    pub containing_range_vs_item_jaccard: f32,
    /// `jaccard_similarity` of the excerpt body's identifiers and the declaration's signature.
    pub containing_range_vs_signature_jaccard: f32,
    /// `jaccard_similarity` of the identifiers on the lines around the cursor and the
    /// declaration's item text.
    pub adjacent_vs_item_jaccard: f32,
    /// `jaccard_similarity` of the identifiers on the lines around the cursor and the
    /// declaration's signature.
    pub adjacent_vs_signature_jaccard: f32,
    /// `weighted_overlap_coefficient` of the excerpt body's identifiers and the declaration's
    /// item text.
    pub containing_range_vs_item_weighted_overlap: f32,
    /// `weighted_overlap_coefficient` of the excerpt body's identifiers and the declaration's
    /// signature.
    pub containing_range_vs_signature_weighted_overlap: f32,
    /// `weighted_overlap_coefficient` of the identifiers on the lines around the cursor and the
    /// declaration's item text.
    pub adjacent_vs_item_weighted_overlap: f32,
    /// `weighted_overlap_coefficient` of the identifiers on the lines around the cursor and the
    /// declaration's signature.
    pub adjacent_vs_signature_weighted_overlap: f32,
}
279
/// The scores derived from a [`ScoreInputs`], one per [`SnippetStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score used for `SnippetStyle::Signature`.
    pub signature: f32,
    /// Score used for `SnippetStyle::Declaration`.
    pub declaration: f32,
}
285
286impl ScoreInputs {
287 fn score(&self) -> Scores {
288 // Score related to how likely this is the correct declaration, range 0 to 1
289 let accuracy_score = if self.is_same_file {
290 // TODO: use declaration_line_distance_rank
291 (0.5 / self.same_file_declaration_count as f32)
292 } else {
293 1.0 / self.declaration_count as f32
294 };
295
296 // Score related to the distance between the reference and cursor, range 0 to 1
297 let distance_score = if self.is_referenced_nearby {
298 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
299 } else {
300 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
301 0.5
302 };
303
304 // For now instead of linear combination, the scores are just multiplied together.
305 let combined_score = 10.0 * accuracy_score * distance_score;
306
307 Scores {
308 signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
309 // declaration score gets boosted both by being multipled by 2 and by there being more
310 // weighted overlap.
311 declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
312 }
313 }
314}
315
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use text::ToOffset;
    use util::path;

    use crate::{
        EditPredictionExcerptOptions, references_in_excerpt, tree_sitter_index::TreeSitterIndex,
    };

    /// Places the cursor inside the `process_data(data)` call in `main` of `c.rs` and checks
    /// that exactly one snippet is produced, for the `process_data` identifier.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // Let pending background work settle before querying the index.
        cx.run_until_parked();

        // first process_data call site
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let excerpt = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &buffer_snapshot,
            &EditPredictionExcerptOptions {
                max_bytes: 40,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
                include_parent_signatures: false,
            },
        )
        .unwrap();
        let excerpt_text = excerpt.text(&buffer_snapshot);
        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer_snapshot);
        let cursor_offset = cursor_point.to_offset(&buffer_snapshot);

        let snippets = cx.update(|cx| {
            scored_snippets(
                index,
                &excerpt,
                &excerpt_text,
                references,
                cursor_offset,
                &buffer_snapshot,
                cx,
            )
        });

        assert_eq!(snippets.len(), 1);
        assert_eq!(snippets[0].identifier.name.as_ref(), "process_data");
        // NOTE(review): the explicit drop presumably exists to keep `buffer` alive until after
        // the assertions — confirm.
        drop(buffer);
    }

    /// Builds the shared fixture: test settings, a fake file system rooted at `/root` holding
    /// `a.rs`, `b.rs` and `c.rs`, a test project over it with the Rust language registered, and
    /// a `TreeSitterIndex` for that project.
    ///
    /// Returns the project, the index and the id of the registered Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<TreeSitterIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| TreeSitterIndex::new(&project, cx));
        // Let pending background work settle before handing the index to the test.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// A `Language` named "Rust" that matches the `rs` path suffix and uses the tree-sitter
    /// Rust grammar with the highlights and outline queries from the `languages` crate sources.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}