1use itertools::Itertools as _;
2use language::BufferSnapshot;
3use ordered_float::OrderedFloat;
4use serde::Serialize;
5use std::{collections::HashMap, ops::Range};
6use strum::EnumIter;
7use text::{OffsetRangeExt, Point, ToPoint};
8
9use crate::{
10 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
11 reference::{Reference, ReferenceRegion},
12 syntax_index::SyntaxIndexState,
13 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
14};
15
/// Passed as the const generic argument to `SyntaxIndexState::declarations_for_identifier`
/// when looking up candidate declarations for each referenced identifier; presumably caps how
/// many declarations are returned per identifier (TODO confirm against the index).
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
17
18// TODO:
19//
20// * Consider adding declaration_file_count
21
/// A candidate declaration for an identifier referenced around the cursor, together with the
/// raw inputs and the resulting scores computed for it.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The referenced identifier this declaration was found for.
    pub identifier: Identifier,
    /// The candidate declaration of `identifier`.
    pub declaration: Declaration,
    /// The raw signals that `scores` was computed from.
    pub score_components: ScoreInputs,
    /// One score per [`SnippetStyle`]; read them through [`ScoredSnippet::score`].
    pub scores: Scores,
}
29
// TODO: Consider having "Concise" style corresponding to `concise_text`
/// Which portion of a declaration a snippet covers. This selects both the score
/// (`ScoredSnippet::score`) and the size (`ScoredSnippet::size`) used for the snippet.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// Only the declaration's signature range.
    Signature,
    /// The declaration's full item text.
    Declaration,
}
36
37impl ScoredSnippet {
38 /// Returns the score for this snippet with the specified style.
39 pub fn score(&self, style: SnippetStyle) -> f32 {
40 match style {
41 SnippetStyle::Signature => self.scores.signature,
42 SnippetStyle::Declaration => self.scores.declaration,
43 }
44 }
45
46 pub fn size(&self, style: SnippetStyle) -> usize {
47 // TODO: how to handle truncation?
48 match &self.declaration {
49 Declaration::File { declaration, .. } => match style {
50 SnippetStyle::Signature => declaration.signature_range_in_text.len(),
51 SnippetStyle::Declaration => declaration.text.len(),
52 },
53 Declaration::Buffer { declaration, .. } => match style {
54 SnippetStyle::Signature => declaration.signature_range.len(),
55 SnippetStyle::Declaration => declaration.item_range.len(),
56 },
57 }
58 }
59
60 pub fn score_density(&self, style: SnippetStyle) -> f32 {
61 self.score(style) / (self.size(style)) as f32
62 }
63}
64
65pub fn scored_snippets(
66 index: &SyntaxIndexState,
67 excerpt: &EditPredictionExcerpt,
68 excerpt_text: &EditPredictionExcerptText,
69 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
70 cursor_offset: usize,
71 current_buffer: &BufferSnapshot,
72) -> Vec<ScoredSnippet> {
73 let containing_range_identifier_occurrences =
74 IdentifierOccurrences::within_string(&excerpt_text.body);
75 let cursor_point = cursor_offset.to_point(¤t_buffer);
76
77 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
78 let end_point = Point::new(cursor_point.row + 1, 0);
79 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
80 ¤t_buffer
81 .text_for_range(start_point..end_point)
82 .collect::<String>(),
83 );
84
85 let mut snippets = identifier_to_references
86 .into_iter()
87 .flat_map(|(identifier, references)| {
88 let declarations =
89 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
90 let declaration_count = declarations.len();
91
92 declarations
93 .iter()
94 .filter_map(|declaration| match declaration {
95 Declaration::Buffer {
96 buffer_id,
97 declaration: buffer_declaration,
98 ..
99 } => {
100 let is_same_file = buffer_id == ¤t_buffer.remote_id();
101
102 if is_same_file {
103 range_intersection(
104 &buffer_declaration.item_range.to_offset(¤t_buffer),
105 &excerpt.range,
106 )
107 .is_none()
108 .then(|| {
109 let declaration_line = buffer_declaration
110 .item_range
111 .start
112 .to_point(current_buffer)
113 .row;
114 (
115 true,
116 (cursor_point.row as i32 - declaration_line as i32)
117 .unsigned_abs(),
118 declaration,
119 )
120 })
121 } else {
122 // TODO should we prefer the current file instead?
123 Some((false, 0, declaration))
124 }
125 }
126 Declaration::File { .. } => {
127 // TODO should we prefer the current file instead?
128 // We can assume that a file declaration is in a different file,
129 // because the current one must be open
130 Some((false, 0, declaration))
131 }
132 })
133 .sorted_by_key(|&(_, distance, _)| distance)
134 .enumerate()
135 .map(
136 |(
137 declaration_line_distance_rank,
138 (is_same_file, declaration_line_distance, declaration),
139 )| {
140 let same_file_declaration_count = index.file_declaration_count(declaration);
141
142 score_snippet(
143 &identifier,
144 &references,
145 declaration.clone(),
146 is_same_file,
147 declaration_line_distance,
148 declaration_line_distance_rank,
149 same_file_declaration_count,
150 declaration_count,
151 &containing_range_identifier_occurrences,
152 &adjacent_identifier_occurrences,
153 cursor_point,
154 current_buffer,
155 )
156 },
157 )
158 .collect::<Vec<_>>()
159 })
160 .flatten()
161 .collect::<Vec<_>>();
162
163 snippets.sort_unstable_by_key(|snippet| {
164 OrderedFloat(
165 snippet
166 .score_density(SnippetStyle::Declaration)
167 .max(snippet.score_density(SnippetStyle::Signature)),
168 )
169 });
170
171 snippets
172}
173
/// Returns the overlap of the half-open ranges `a` and `b`, or `None` when they share no
/// elements.
///
/// Ranges that merely touch (such as `0..5` and `5..10`) do not overlap. An empty overlap is
/// reported as `None` rather than as an empty range.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lo = std::cmp::max(a.start.clone(), b.start.clone());
    let hi = std::cmp::min(a.end.clone(), b.end.clone());
    (lo < hi).then_some(lo..hi)
}
183
184fn score_snippet(
185 identifier: &Identifier,
186 references: &[Reference],
187 declaration: Declaration,
188 is_same_file: bool,
189 declaration_line_distance: u32,
190 declaration_line_distance_rank: usize,
191 same_file_declaration_count: usize,
192 declaration_count: usize,
193 containing_range_identifier_occurrences: &IdentifierOccurrences,
194 adjacent_identifier_occurrences: &IdentifierOccurrences,
195 cursor: Point,
196 current_buffer: &BufferSnapshot,
197) -> Option<ScoredSnippet> {
198 let is_referenced_nearby = references
199 .iter()
200 .any(|r| r.region == ReferenceRegion::Nearby);
201 let is_referenced_in_breadcrumb = references
202 .iter()
203 .any(|r| r.region == ReferenceRegion::Breadcrumb);
204 let reference_count = references.len();
205 let reference_line_distance = references
206 .iter()
207 .map(|r| {
208 let reference_line = r.range.start.to_point(current_buffer).row as i32;
209 (cursor.row as i32 - reference_line).unsigned_abs()
210 })
211 .min()
212 .unwrap();
213
214 let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
215 let item_signature_occurrences =
216 IdentifierOccurrences::within_string(&declaration.signature_text().0);
217 let containing_range_vs_item_jaccard = jaccard_similarity(
218 containing_range_identifier_occurrences,
219 &item_source_occurrences,
220 );
221 let containing_range_vs_signature_jaccard = jaccard_similarity(
222 containing_range_identifier_occurrences,
223 &item_signature_occurrences,
224 );
225 let adjacent_vs_item_jaccard =
226 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
227 let adjacent_vs_signature_jaccard =
228 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
229
230 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
231 containing_range_identifier_occurrences,
232 &item_source_occurrences,
233 );
234 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
235 containing_range_identifier_occurrences,
236 &item_signature_occurrences,
237 );
238 let adjacent_vs_item_weighted_overlap =
239 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
240 let adjacent_vs_signature_weighted_overlap =
241 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
242
243 let score_components = ScoreInputs {
244 is_same_file,
245 is_referenced_nearby,
246 is_referenced_in_breadcrumb,
247 reference_line_distance,
248 declaration_line_distance,
249 declaration_line_distance_rank,
250 reference_count,
251 same_file_declaration_count,
252 declaration_count,
253 containing_range_vs_item_jaccard,
254 containing_range_vs_signature_jaccard,
255 adjacent_vs_item_jaccard,
256 adjacent_vs_signature_jaccard,
257 containing_range_vs_item_weighted_overlap,
258 containing_range_vs_signature_weighted_overlap,
259 adjacent_vs_item_weighted_overlap,
260 adjacent_vs_signature_weighted_overlap,
261 };
262
263 Some(ScoredSnippet {
264 identifier: identifier.clone(),
265 declaration: declaration,
266 scores: score_components.score(),
267 score_components,
268 })
269}
270
/// Raw signals describing how a candidate declaration relates to the cursor and its
/// surroundings. `ScoreInputs::score` combines them into [`Scores`].
#[derive(Clone, Debug, Serialize)]
pub struct ScoreInputs {
    /// Whether the declaration is a buffer declaration in the current buffer.
    pub is_same_file: bool,
    /// Whether any reference to the identifier is in the `ReferenceRegion::Nearby` region.
    pub is_referenced_nearby: bool,
    /// Whether any reference to the identifier is in the `ReferenceRegion::Breadcrumb` region.
    pub is_referenced_in_breadcrumb: bool,
    /// Number of references to the identifier.
    pub reference_count: usize,
    /// Value of `SyntaxIndexState::file_declaration_count` for this declaration. Presumably
    /// the number of declarations in the declaration's file (TODO confirm).
    pub same_file_declaration_count: usize,
    /// Number of declarations returned for the identifier by
    /// `declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>`. Counted before
    /// excerpt-overlapping same-file declarations are skipped.
    pub declaration_count: usize,
    /// Smallest row distance between the cursor and the start of any reference.
    pub reference_line_distance: u32,
    /// Row distance between the cursor and the start of the declaration's item for
    /// same-file declarations; 0 for declarations in other files.
    pub declaration_line_distance: u32,
    /// Zero-based position of this declaration among the identifier's candidates when they
    /// are sorted by `declaration_line_distance`.
    pub declaration_line_distance_rank: usize,
    /// Jaccard similarity of identifier occurrences: excerpt body vs. declaration item text.
    pub containing_range_vs_item_jaccard: f32,
    /// Jaccard similarity of identifier occurrences: excerpt body vs. declaration signature.
    pub containing_range_vs_signature_jaccard: f32,
    /// Jaccard similarity of identifier occurrences: lines around the cursor (cursor row and
    /// up to two rows above) vs. declaration item text.
    pub adjacent_vs_item_jaccard: f32,
    /// Jaccard similarity of identifier occurrences: lines around the cursor vs. declaration
    /// signature.
    pub adjacent_vs_signature_jaccard: f32,
    /// Weighted overlap coefficient: excerpt body vs. declaration item text.
    pub containing_range_vs_item_weighted_overlap: f32,
    /// Weighted overlap coefficient: excerpt body vs. declaration signature.
    pub containing_range_vs_signature_weighted_overlap: f32,
    /// Weighted overlap coefficient: lines around the cursor vs. declaration item text.
    pub adjacent_vs_item_weighted_overlap: f32,
    /// Weighted overlap coefficient: lines around the cursor vs. declaration signature.
    pub adjacent_vs_signature_weighted_overlap: f32,
}
291
/// Final scores of a candidate declaration, one per [`SnippetStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score used for [`SnippetStyle::Signature`].
    pub signature: f32,
    /// Score used for [`SnippetStyle::Declaration`].
    pub declaration: f32,
}
297
298impl ScoreInputs {
299 fn score(&self) -> Scores {
300 // Score related to how likely this is the correct declaration, range 0 to 1
301 let accuracy_score = if self.is_same_file {
302 // TODO: use declaration_line_distance_rank
303 1.0 / self.same_file_declaration_count as f32
304 } else {
305 1.0 / self.declaration_count as f32
306 };
307
308 // Score related to the distance between the reference and cursor, range 0 to 1
309 let distance_score = if self.is_referenced_nearby {
310 1.0 / (1.0 + self.reference_line_distance as f32 / 10.0).powf(2.0)
311 } else {
312 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
313 0.5
314 };
315
316 // For now instead of linear combination, the scores are just multiplied together.
317 let combined_score = 10.0 * accuracy_score * distance_score;
318
319 Scores {
320 signature: combined_score * self.containing_range_vs_signature_weighted_overlap,
321 // declaration score gets boosted both by being multiplied by 2 and by there being more
322 // weighted overlap.
323 declaration: 2.0 * combined_score * self.containing_range_vs_item_weighted_overlap,
324 }
325 }
326}