1use itertools::Itertools as _;
2use language::BufferSnapshot;
3use ordered_float::OrderedFloat;
4use serde::Serialize;
5use std::{collections::HashMap, ops::Range};
6use strum::EnumIter;
7use text::{OffsetRangeExt, Point, ToPoint};
8
9use crate::{
10 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
11 reference::{Reference, ReferenceRegion},
12 syntax_index::SyntaxIndexState,
13 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
14};
15
/// Const-generic limit passed to `SyntaxIndexState::declarations_for_identifier` when looking up
/// the candidate declarations for each referenced identifier in [`scored_snippets`].
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
17
18// TODO:
19//
20// * Consider adding declaration_file_count
21
/// A candidate declaration for a referenced identifier, together with its scores.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The identifier that this declaration is a candidate for.
    pub identifier: Identifier,
    /// The candidate declaration (from an open buffer or from a file).
    pub declaration: Declaration,
    /// The raw inputs that `scores` was computed from.
    pub score_components: ScoreInputs,
    /// Per-style scores; read them through `ScoredSnippet::score`.
    pub scores: Scores,
}
29
// TODO: Consider having "Concise" style corresponding to `concise_text`
/// How much of a declaration a snippet includes.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// Only the declaration's signature (`signature_range` / `signature_range_in_text`).
    Signature,
    /// The whole declaration item (`item_range` for buffers, `text` for files).
    Declaration,
}
36
37impl ScoredSnippet {
38 /// Returns the score for this snippet with the specified style.
39 pub fn score(&self, style: SnippetStyle) -> f32 {
40 match style {
41 SnippetStyle::Signature => self.scores.signature,
42 SnippetStyle::Declaration => self.scores.declaration,
43 }
44 }
45
46 pub fn size(&self, style: SnippetStyle) -> usize {
47 // TODO: how to handle truncation?
48 match &self.declaration {
49 Declaration::File { declaration, .. } => match style {
50 SnippetStyle::Signature => declaration.signature_range_in_text.len(),
51 SnippetStyle::Declaration => declaration.text.len(),
52 },
53 Declaration::Buffer { declaration, .. } => match style {
54 SnippetStyle::Signature => declaration.signature_range.len(),
55 SnippetStyle::Declaration => declaration.item_range.len(),
56 },
57 }
58 }
59
60 pub fn score_density(&self, style: SnippetStyle) -> f32 {
61 self.score(style) / (self.size(style)) as f32
62 }
63}
64
65pub fn scored_snippets(
66 index: &SyntaxIndexState,
67 excerpt: &EditPredictionExcerpt,
68 excerpt_text: &EditPredictionExcerptText,
69 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
70 cursor_offset: usize,
71 current_buffer: &BufferSnapshot,
72) -> Vec<ScoredSnippet> {
73 let containing_range_identifier_occurrences =
74 IdentifierOccurrences::within_string(&excerpt_text.body);
75 let cursor_point = cursor_offset.to_point(¤t_buffer);
76
77 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
78 let end_point = Point::new(cursor_point.row + 1, 0);
79 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
80 ¤t_buffer
81 .text_for_range(start_point..end_point)
82 .collect::<String>(),
83 );
84
85 let mut snippets = identifier_to_references
86 .into_iter()
87 .flat_map(|(identifier, references)| {
88 let declarations =
89 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
90 let declaration_count = declarations.len();
91
92 declarations
93 .iter()
94 .filter_map(|declaration| match declaration {
95 Declaration::Buffer {
96 buffer_id,
97 declaration: buffer_declaration,
98 ..
99 } => {
100 let is_same_file = buffer_id == ¤t_buffer.remote_id();
101
102 if is_same_file {
103 range_intersection(
104 &buffer_declaration.item_range.to_offset(¤t_buffer),
105 &excerpt.range,
106 )
107 .is_none()
108 .then(|| {
109 let declaration_line = buffer_declaration
110 .item_range
111 .start
112 .to_point(current_buffer)
113 .row;
114 (
115 true,
116 (cursor_point.row as i32 - declaration_line as i32)
117 .unsigned_abs(),
118 declaration,
119 )
120 })
121 } else {
122 Some((false, u32::MAX, declaration))
123 }
124 }
125 Declaration::File { .. } => {
126 // We can assume that a file declaration is in a different file,
127 // because the current one must be open
128 Some((false, u32::MAX, declaration))
129 }
130 })
131 .sorted_by_key(|&(_, distance, _)| distance)
132 .enumerate()
133 .map(
134 |(
135 declaration_line_distance_rank,
136 (is_same_file, declaration_line_distance, declaration),
137 )| {
138 let same_file_declaration_count = index.file_declaration_count(declaration);
139
140 score_snippet(
141 &identifier,
142 &references,
143 declaration.clone(),
144 is_same_file,
145 declaration_line_distance,
146 declaration_line_distance_rank,
147 same_file_declaration_count,
148 declaration_count,
149 &containing_range_identifier_occurrences,
150 &adjacent_identifier_occurrences,
151 cursor_point,
152 current_buffer,
153 )
154 },
155 )
156 .collect::<Vec<_>>()
157 })
158 .flatten()
159 .collect::<Vec<_>>();
160
161 snippets.sort_unstable_by_key(|snippet| {
162 OrderedFloat(
163 snippet
164 .score_density(SnippetStyle::Declaration)
165 .max(snippet.score_density(SnippetStyle::Signature)),
166 )
167 });
168
169 snippets
170}
171
/// Returns the overlap of two half-open ranges, or `None` if they share no element.
///
/// Ranges that merely touch (e.g. `0..3` and `3..5`) are treated as disjoint, as is any
/// overlap that would be empty.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then(|| lower..upper)
}
181
182fn score_snippet(
183 identifier: &Identifier,
184 references: &[Reference],
185 declaration: Declaration,
186 is_same_file: bool,
187 declaration_line_distance: u32,
188 declaration_line_distance_rank: usize,
189 same_file_declaration_count: usize,
190 declaration_count: usize,
191 containing_range_identifier_occurrences: &IdentifierOccurrences,
192 adjacent_identifier_occurrences: &IdentifierOccurrences,
193 cursor: Point,
194 current_buffer: &BufferSnapshot,
195) -> Option<ScoredSnippet> {
196 let is_referenced_nearby = references
197 .iter()
198 .any(|r| r.region == ReferenceRegion::Nearby);
199 let is_referenced_in_breadcrumb = references
200 .iter()
201 .any(|r| r.region == ReferenceRegion::Breadcrumb);
202 let reference_count = references.len();
203 let reference_line_distance = references
204 .iter()
205 .map(|r| {
206 let reference_line = r.range.start.to_point(current_buffer).row as i32;
207 (cursor.row as i32 - reference_line).unsigned_abs()
208 })
209 .min()
210 .unwrap();
211
212 let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
213 let item_signature_occurrences =
214 IdentifierOccurrences::within_string(&declaration.signature_text().0);
215 let containing_range_vs_item_jaccard = jaccard_similarity(
216 containing_range_identifier_occurrences,
217 &item_source_occurrences,
218 );
219 let containing_range_vs_signature_jaccard = jaccard_similarity(
220 containing_range_identifier_occurrences,
221 &item_signature_occurrences,
222 );
223 let adjacent_vs_item_jaccard =
224 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
225 let adjacent_vs_signature_jaccard =
226 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
227
228 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
229 containing_range_identifier_occurrences,
230 &item_source_occurrences,
231 );
232 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
233 containing_range_identifier_occurrences,
234 &item_signature_occurrences,
235 );
236 let adjacent_vs_item_weighted_overlap =
237 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
238 let adjacent_vs_signature_weighted_overlap =
239 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
240
241 let score_components = ScoreInputs {
242 is_same_file,
243 is_referenced_nearby,
244 is_referenced_in_breadcrumb,
245 reference_line_distance,
246 declaration_line_distance,
247 declaration_line_distance_rank,
248 reference_count,
249 same_file_declaration_count,
250 declaration_count,
251 containing_range_vs_item_jaccard,
252 containing_range_vs_signature_jaccard,
253 adjacent_vs_item_jaccard,
254 adjacent_vs_signature_jaccard,
255 containing_range_vs_item_weighted_overlap,
256 containing_range_vs_signature_weighted_overlap,
257 adjacent_vs_item_weighted_overlap,
258 adjacent_vs_signature_weighted_overlap,
259 };
260
261 Some(ScoredSnippet {
262 identifier: identifier.clone(),
263 declaration: declaration,
264 scores: score_components.score(),
265 score_components,
266 })
267}
268
/// Raw features of one candidate declaration, as gathered by `score_snippet`.
///
/// Serializable via serde. `ScoreInputs::score` turns these into [`Scores`].
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
    /// Whether the declaration lives in the current buffer.
    pub is_same_file: bool,
    /// Whether any reference to the identifier is in `ReferenceRegion::Nearby`.
    pub is_referenced_nearby: bool,
    /// Whether any reference to the identifier is in `ReferenceRegion::Breadcrumb`.
    pub is_referenced_in_breadcrumb: bool,
    /// Number of references found for the identifier.
    pub reference_count: usize,
    /// Value of `SyntaxIndexState::file_declaration_count` for this declaration.
    // NOTE(review): exact meaning depends on that method (not visible here) — confirm.
    pub same_file_declaration_count: usize,
    /// Number of declarations the index returned for the identifier, counted before
    /// declarations overlapping the excerpt are filtered out.
    pub declaration_count: usize,
    /// Smallest row distance between the cursor and the start of a reference.
    pub reference_line_distance: u32,
    /// Row distance between the cursor and the start of the declaration's item.
    /// `u32::MAX` for declarations in other files.
    pub declaration_line_distance: u32,
    /// Position of this declaration when the identifier's candidates are sorted by
    /// `declaration_line_distance` (0 = closest).
    pub declaration_line_distance_rank: usize,
    /// Jaccard similarity: excerpt identifiers vs. declaration item identifiers.
    pub containing_range_vs_item_jaccard: f32,
    /// Jaccard similarity: excerpt identifiers vs. declaration signature identifiers.
    pub containing_range_vs_signature_jaccard: f32,
    /// Jaccard similarity: identifiers near the cursor vs. declaration item identifiers.
    pub adjacent_vs_item_jaccard: f32,
    /// Jaccard similarity: identifiers near the cursor vs. declaration signature identifiers.
    pub adjacent_vs_signature_jaccard: f32,
    /// Weighted overlap coefficient: excerpt identifiers vs. declaration item identifiers.
    pub containing_range_vs_item_weighted_overlap: f32,
    /// Weighted overlap coefficient: excerpt identifiers vs. declaration signature identifiers.
    pub containing_range_vs_signature_weighted_overlap: f32,
    /// Weighted overlap coefficient: identifiers near the cursor vs. declaration item
    /// identifiers.
    pub adjacent_vs_item_weighted_overlap: f32,
    /// Weighted overlap coefficient: identifiers near the cursor vs. declaration signature
    /// identifiers.
    pub adjacent_vs_signature_weighted_overlap: f32,
}
289
/// Final scores of a candidate declaration, one per `SnippetStyle`.
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score when only the declaration's signature is included.
    pub signature: f32,
    /// Score when the whole declaration item is included.
    pub declaration: f32,
}
295
296impl ScoreInputs {
297 fn score(&self) -> Scores {
298 // Score related to how likely this is the correct declaration, range 0 to 1
299 let accuracy_score = if self.is_same_file {
300 // TODO: use declaration_line_distance_rank
301 1.0 / self.same_file_declaration_count as f32
302 } else {
303 1.0 / self.declaration_count as f32
304 };
305
306 // Score related to the distance between the reference and cursor, range 0 to 1
307 let distance_score = if self.is_referenced_nearby {
308 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
309 } else {
310 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
311 0.5
312 };
313
314 // For now instead of linear combination, the scores are just multiplied together.
315 let combined_score = 10.0 * accuracy_score * distance_score;
316
317 Scores {
318 signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
319 // declaration score gets boosted both by being multiplied by 2 and by there being more
320 // weighted overlap.
321 declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
322 }
323 }
324}