1use gpui::{App, Entity};
2use itertools::Itertools as _;
3use language::BufferSnapshot;
4use serde::Serialize;
5use std::{collections::HashMap, ops::Range};
6use strum::EnumIter;
7use text::{OffsetRangeExt, Point, ToPoint};
8
9use crate::{
10 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier, SyntaxIndex,
11 reference::{Reference, ReferenceRegion},
12 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
13};
14
15// TODO:
16//
17// * Consider adding declaration_file_count (n)
18
/// A candidate declaration for a referenced identifier, together with the raw inputs
/// used to score it and the resulting scores (built by `score_snippet`).
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The referenced identifier that led to this declaration being considered.
    // NOTE(review): `dead_code` is allowed here, presumably because only tests read this
    // field — confirm before removing the attribute.
    #[allow(dead_code)]
    pub identifier: Identifier,
    /// The candidate declaration of `identifier`.
    pub declaration: Declaration,
    /// Raw features that were combined into `scores`.
    pub score_components: ScoreInputs,
    /// Final scores, selected per [`SnippetStyle`] by `ScoredSnippet::score`.
    pub scores: Scores,
}
27
// TODO: Consider having "Concise" style corresponding to `concise_text`
/// Which form of a declaration a snippet uses; selects the matching score in
/// `ScoredSnippet::score`.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// The declaration's signature; scored against its signature text.
    Signature,
    /// The full declaration; scored against its complete item text.
    Declaration,
}
34
35impl ScoredSnippet {
36 /// Returns the score for this snippet with the specified style.
37 pub fn score(&self, style: SnippetStyle) -> f32 {
38 match style {
39 SnippetStyle::Signature => self.scores.signature,
40 SnippetStyle::Declaration => self.scores.declaration,
41 }
42 }
43
44 pub fn size(&self, style: SnippetStyle) -> usize {
45 todo!()
46 }
47
48 pub fn score_density(&self, style: SnippetStyle) -> f32 {
49 self.score(style) / (self.size(style)) as f32
50 }
51}
52
53fn scored_snippets(
54 index: Entity<SyntaxIndex>,
55 excerpt: &EditPredictionExcerpt,
56 excerpt_text: &EditPredictionExcerptText,
57 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
58 cursor_offset: usize,
59 current_buffer: &BufferSnapshot,
60 cx: &App,
61) -> Vec<ScoredSnippet> {
62 let containing_range_identifier_occurrences =
63 IdentifierOccurrences::within_string(&excerpt_text.body);
64 let cursor_point = cursor_offset.to_point(¤t_buffer);
65
66 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
67 let end_point = Point::new(cursor_point.row + 1, 0);
68 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
69 ¤t_buffer
70 .text_for_range(start_point..end_point)
71 .collect::<String>(),
72 );
73
74 identifier_to_references
75 .into_iter()
76 .flat_map(|(identifier, references)| {
77 let declarations = index
78 .read(cx)
79 // todo! pick a limit
80 .declarations_for_identifier::<16>(&identifier, cx);
81 let declaration_count = declarations.len();
82
83 declarations
84 .iter()
85 .filter_map(|declaration| match declaration {
86 Declaration::Buffer {
87 declaration: buffer_declaration,
88 buffer,
89 } => {
90 let is_same_file = buffer
91 .read_with(cx, |buffer, _| buffer.remote_id())
92 .is_ok_and(|buffer_id| buffer_id == current_buffer.remote_id());
93
94 if is_same_file {
95 range_intersection(
96 &buffer_declaration.item_range.to_offset(¤t_buffer),
97 &excerpt.range,
98 )
99 .is_none()
100 .then(|| {
101 let declaration_line = buffer_declaration
102 .item_range
103 .start
104 .to_point(current_buffer)
105 .row;
106 (
107 true,
108 (cursor_point.row as i32 - declaration_line as i32).abs()
109 as u32,
110 declaration,
111 )
112 })
113 } else {
114 Some((false, 0, declaration))
115 }
116 }
117 Declaration::File { .. } => {
118 // We can assume that a file declaration is in a different file,
119 // because the current one must be open
120 Some((false, 0, declaration))
121 }
122 })
123 .sorted_by_key(|&(_, distance, _)| distance)
124 .enumerate()
125 .map(
126 |(
127 declaration_line_distance_rank,
128 (is_same_file, declaration_line_distance, declaration),
129 )| {
130 let same_file_declaration_count =
131 index.read(cx).file_declaration_count(declaration);
132
133 score_snippet(
134 &identifier,
135 &references,
136 declaration.clone(),
137 is_same_file,
138 declaration_line_distance,
139 declaration_line_distance_rank,
140 same_file_declaration_count,
141 declaration_count,
142 &containing_range_identifier_occurrences,
143 &adjacent_identifier_occurrences,
144 cursor_point,
145 current_buffer,
146 cx,
147 )
148 },
149 )
150 .collect::<Vec<_>>()
151 })
152 .flatten()
153 .collect::<Vec<_>>()
154}
155
// todo! replace with existing util?
/// Returns the overlap of the half-open ranges `a` and `b`, or `None` when they share no
/// elements. Ranges that merely touch (e.g. `0..3` and `3..5`) do not intersect, and an
/// empty or inverted input never produces an intersection.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then(|| lower..upper)
}
166
167fn score_snippet(
168 identifier: &Identifier,
169 references: &[Reference],
170 declaration: Declaration,
171 is_same_file: bool,
172 declaration_line_distance: u32,
173 declaration_line_distance_rank: usize,
174 same_file_declaration_count: usize,
175 declaration_count: usize,
176 containing_range_identifier_occurrences: &IdentifierOccurrences,
177 adjacent_identifier_occurrences: &IdentifierOccurrences,
178 cursor: Point,
179 current_buffer: &BufferSnapshot,
180 cx: &App,
181) -> Option<ScoredSnippet> {
182 let is_referenced_nearby = references
183 .iter()
184 .any(|r| r.region == ReferenceRegion::Nearby);
185 let is_referenced_in_breadcrumb = references
186 .iter()
187 .any(|r| r.region == ReferenceRegion::Breadcrumb);
188 let reference_count = references.len();
189 let reference_line_distance = references
190 .iter()
191 .map(|r| {
192 let reference_line = r.range.start.to_point(current_buffer).row as i32;
193 (cursor.row as i32 - reference_line).abs() as u32
194 })
195 .min()
196 .unwrap();
197
198 let item_source_occurrences =
199 IdentifierOccurrences::within_string(&declaration.item_text(cx).0);
200 let item_signature_occurrences =
201 IdentifierOccurrences::within_string(&declaration.signature_text(cx).0);
202 let containing_range_vs_item_jaccard = jaccard_similarity(
203 containing_range_identifier_occurrences,
204 &item_source_occurrences,
205 );
206 let containing_range_vs_signature_jaccard = jaccard_similarity(
207 containing_range_identifier_occurrences,
208 &item_signature_occurrences,
209 );
210 let adjacent_vs_item_jaccard =
211 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
212 let adjacent_vs_signature_jaccard =
213 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
214
215 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
216 containing_range_identifier_occurrences,
217 &item_source_occurrences,
218 );
219 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
220 containing_range_identifier_occurrences,
221 &item_signature_occurrences,
222 );
223 let adjacent_vs_item_weighted_overlap =
224 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
225 let adjacent_vs_signature_weighted_overlap =
226 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
227
228 let score_components = ScoreInputs {
229 is_same_file,
230 is_referenced_nearby,
231 is_referenced_in_breadcrumb,
232 reference_line_distance,
233 declaration_line_distance,
234 declaration_line_distance_rank,
235 reference_count,
236 same_file_declaration_count,
237 declaration_count,
238 containing_range_vs_item_jaccard,
239 containing_range_vs_signature_jaccard,
240 adjacent_vs_item_jaccard,
241 adjacent_vs_signature_jaccard,
242 containing_range_vs_item_weighted_overlap,
243 containing_range_vs_signature_weighted_overlap,
244 adjacent_vs_item_weighted_overlap,
245 adjacent_vs_signature_weighted_overlap,
246 };
247
248 Some(ScoredSnippet {
249 identifier: identifier.clone(),
250 declaration: declaration,
251 scores: score_components.score(),
252 score_components,
253 })
254}
255
/// Raw features computed for one candidate declaration by `score_snippet`, combined into
/// [`Scores`] by `ScoreInputs::score`.
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
    /// Whether the declaration is in the current buffer.
    pub is_same_file: bool,
    /// Whether any reference to the identifier is in the `Nearby` region.
    pub is_referenced_nearby: bool,
    /// Whether any reference to the identifier is in the `Breadcrumb` region.
    pub is_referenced_in_breadcrumb: bool,
    /// Number of references to the identifier that were passed to `score_snippet`.
    pub reference_count: usize,
    /// Value of `SyntaxIndex::file_declaration_count` for this declaration
    /// (presumably the number of declarations in the declaration's file — confirm).
    pub same_file_declaration_count: usize,
    /// Number of declarations found for the identifier (the lookup is capped at 16).
    pub declaration_count: usize,
    /// Smallest row distance between the cursor and any reference.
    pub reference_line_distance: u32,
    /// Row distance between the cursor and the start of the declaration; 0 when the
    /// declaration is in another file.
    pub declaration_line_distance: u32,
    /// Position of this declaration after sorting the identifier's candidates by
    /// `declaration_line_distance`.
    pub declaration_line_distance_rank: usize,
    // Similarity features. "containing_range" is the excerpt body; "adjacent" is the two
    // rows above the cursor plus the cursor row; "item" is the declaration's full text
    // and "signature" is its signature text.
    pub containing_range_vs_item_jaccard: f32,
    pub containing_range_vs_signature_jaccard: f32,
    pub adjacent_vs_item_jaccard: f32,
    pub adjacent_vs_signature_jaccard: f32,
    pub containing_range_vs_item_weighted_overlap: f32,
    pub containing_range_vs_signature_weighted_overlap: f32,
    pub adjacent_vs_item_weighted_overlap: f32,
    pub adjacent_vs_signature_weighted_overlap: f32,
}
276
/// Final scores for a snippet, one per [`SnippetStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score for the [`SnippetStyle::Signature`] form.
    pub signature: f32,
    /// Score for the [`SnippetStyle::Declaration`] form.
    pub declaration: f32,
}
282
283impl ScoreInputs {
284 fn score(&self) -> Scores {
285 // Score related to how likely this is the correct declaration, range 0 to 1
286 let accuracy_score = if self.is_same_file {
287 // TODO: use declaration_line_distance_rank
288 (0.5 / self.same_file_declaration_count as f32)
289 } else {
290 1.0 / self.declaration_count as f32
291 };
292
293 // Score related to the distance between the reference and cursor, range 0 to 1
294 let distance_score = if self.is_referenced_nearby {
295 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
296 } else {
297 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
298 0.5
299 };
300
301 // For now instead of linear combination, the scores are just multiplied together.
302 let combined_score = 10.0 * accuracy_score * distance_score;
303
304 Scores {
305 signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
306 // declaration score gets boosted both by being multipled by 2 and by there being more
307 // weighted overlap.
308 declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
309 }
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    use gpui::{TestAppContext, prelude::*};
    use indoc::indoc;
    use language::{Language, LanguageConfig, LanguageId, LanguageMatcher, tree_sitter_rust};
    use project::{FakeFs, Project};
    use serde_json::json;
    use settings::SettingsStore;
    use text::ToOffset;
    use util::path;

    use crate::{EditPredictionExcerptOptions, references_in_excerpt};

    /// Places the cursor inside the first `process_data` call in `c.rs`, builds a small
    /// excerpt around it, and checks that scoring yields exactly one snippet, whose
    /// identifier is `process_data`.
    #[gpui::test]
    async fn test_call_site(cx: &mut TestAppContext) {
        let (project, index, _rust_lang_id) = init_test(cx).await;

        let buffer = project
            .update(cx, |project, cx| {
                let project_path = project.find_project_path("c.rs", cx).unwrap();
                project.open_buffer(project_path, cx)
            })
            .await
            .unwrap();

        // NOTE(review): presumably lets background work triggered by opening the buffer
        // (e.g. parsing / indexing) settle before taking a snapshot — confirm.
        cx.run_until_parked();

        // first process_data call site
        // (row 8 of c.rs is `let result = process_data(data);`; column 21 falls inside
        // `process_data`)
        let cursor_point = language::Point::new(8, 21);
        let buffer_snapshot = buffer.read_with(cx, |buffer, _| buffer.snapshot());
        let excerpt = EditPredictionExcerpt::select_from_buffer(
            cursor_point,
            &buffer_snapshot,
            &EditPredictionExcerptOptions {
                max_bytes: 40,
                min_bytes: 10,
                target_before_cursor_over_total_bytes: 0.5,
                include_parent_signatures: false,
            },
        )
        .unwrap();
        let excerpt_text = excerpt.text(&buffer_snapshot);
        let references = references_in_excerpt(&excerpt, &excerpt_text, &buffer_snapshot);
        let cursor_offset = cursor_point.to_offset(&buffer_snapshot);

        let snippets = cx.update(|cx| {
            scored_snippets(
                index,
                &excerpt,
                &excerpt_text,
                references,
                cursor_offset,
                &buffer_snapshot,
                cx,
            )
        });

        assert_eq!(snippets.len(), 1);
        assert_eq!(snippets[0].identifier.name.as_ref(), "process_data");
        // Explicitly keeps the buffer entity alive until this point.
        // NOTE(review): presumably so the buffer stays open while scoring — confirm.
        drop(buffer);
    }

    /// Shared fixture for the tests above. It:
    /// * installs test settings and initializes languages;
    /// * creates a fake filesystem at `/root` with `a.rs`, `b.rs` and `c.rs`;
    /// * opens a test project on it and registers a Rust language;
    /// * builds a `SyntaxIndex` over the project.
    ///
    /// Returns the project, the index, and the id of the registered Rust language.
    async fn init_test(
        cx: &mut TestAppContext,
    ) -> (Entity<Project>, Entity<SyntaxIndex>, LanguageId) {
        cx.update(|cx| {
            let settings_store = SettingsStore::test(cx);
            cx.set_global(settings_store);
            language::init(cx);
            Project::init_settings(cx);
        });

        let fs = FakeFs::new(cx.executor());
        fs.insert_tree(
            path!("/root"),
            json!({
                "a.rs": indoc! {r#"
                    fn main() {
                        let x = 1;
                        let y = 2;
                        let z = add(x, y);
                        println!("Result: {}", z);
                    }

                    fn add(a: i32, b: i32) -> i32 {
                        a + b
                    }
                "#},
                "b.rs": indoc! {"
                    pub struct Config {
                        pub name: String,
                        pub value: i32,
                    }

                    impl Config {
                        pub fn new(name: String, value: i32) -> Self {
                            Config { name, value }
                        }
                    }
                "},
                "c.rs": indoc! {r#"
                    use std::collections::HashMap;

                    fn main() {
                        let args: Vec<String> = std::env::args().collect();
                        let data: Vec<i32> = args[1..]
                            .iter()
                            .filter_map(|s| s.parse().ok())
                            .collect();
                        let result = process_data(data);
                        println!("{:?}", result);
                    }

                    fn process_data(data: Vec<i32>) -> HashMap<i32, usize> {
                        let mut counts = HashMap::new();
                        for value in data {
                            *counts.entry(value).or_insert(0) += 1;
                        }
                        counts
                    }

                    #[cfg(test)]
                    mod tests {
                        use super::*;

                        #[test]
                        fn test_process_data() {
                            let data = vec![1, 2, 2, 3];
                            let result = process_data(data);
                            assert_eq!(result.get(&2), Some(&2));
                        }
                    }
                "#}
            }),
        )
        .await;
        let project = Project::test(fs.clone(), [path!("/root").as_ref()], cx).await;
        let language_registry = project.read_with(cx, |project, _| project.languages().clone());
        let lang = rust_lang();
        let lang_id = lang.id();
        language_registry.add(Arc::new(lang));

        let index = cx.new(|cx| SyntaxIndex::new(&project, cx));
        // NOTE(review): presumably waits for the index's initial work to finish — confirm.
        cx.run_until_parked();

        (project, index, lang_id)
    }

    /// Builds a Rust `Language`, matched by the `rs` path suffix, with the highlights and
    /// outline queries loaded from the `languages` crate.
    fn rust_lang() -> Language {
        Language::new(
            LanguageConfig {
                name: "Rust".into(),
                matcher: LanguageMatcher {
                    path_suffixes: vec!["rs".to_string()],
                    ..Default::default()
                },
                ..Default::default()
            },
            Some(tree_sitter_rust::LANGUAGE.into()),
        )
        .with_highlights_query(include_str!("../../languages/src/rust/highlights.scm"))
        .unwrap()
        .with_outline_query(include_str!("../../languages/src/rust/outline.scm"))
        .unwrap()
    }
}