1use cloud_llm_client::predict_edits_v3::ScoreComponents;
2use itertools::Itertools as _;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use serde::Serialize;
6use std::{collections::HashMap, ops::Range};
7use strum::EnumIter;
8use text::{Point, ToPoint};
9
10use crate::{
11 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
12 reference::{Reference, ReferenceRegion},
13 syntax_index::SyntaxIndexState,
14 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
15};
16
/// Const generic argument passed to `SyntaxIndexState::declarations_for_identifier` when
/// looking up the declarations of a single identifier.
// NOTE(review): presumably an upper bound on how many declarations are returned per
// identifier; confirm against `declarations_for_identifier`.
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
18
19// TODO:
20//
21// * Consider adding declaration_file_count
22
/// A declaration of a referenced identifier, together with the data used to rank it.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The identifier whose references caused this declaration to be considered.
    pub identifier: Identifier,
    /// The declaration that was scored.
    pub declaration: Declaration,
    /// The raw measurements that `scores` was computed from.
    pub score_components: ScoreComponents,
    /// Per-style scores derived from `score_components`.
    pub scores: Scores,
}
30
// TODO: Consider having "Concise" style corresponding to `concise_text`
/// Selects which form of a declaration a score or size refers to.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// Only the signature portion of the declaration.
    Signature,
    /// The whole declaration item.
    Declaration,
}
37
38impl ScoredSnippet {
39 /// Returns the score for this snippet with the specified style.
40 pub fn score(&self, style: SnippetStyle) -> f32 {
41 match style {
42 SnippetStyle::Signature => self.scores.signature,
43 SnippetStyle::Declaration => self.scores.declaration,
44 }
45 }
46
47 pub fn size(&self, style: SnippetStyle) -> usize {
48 // TODO: how to handle truncation?
49 match &self.declaration {
50 Declaration::File { declaration, .. } => match style {
51 SnippetStyle::Signature => declaration.signature_range_in_text.len(),
52 SnippetStyle::Declaration => declaration.text.len(),
53 },
54 Declaration::Buffer { declaration, .. } => match style {
55 SnippetStyle::Signature => declaration.signature_range.len(),
56 SnippetStyle::Declaration => declaration.item_range.len(),
57 },
58 }
59 }
60
61 pub fn score_density(&self, style: SnippetStyle) -> f32 {
62 self.score(style) / (self.size(style)) as f32
63 }
64}
65
66pub fn scored_snippets(
67 index: &SyntaxIndexState,
68 excerpt: &EditPredictionExcerpt,
69 excerpt_text: &EditPredictionExcerptText,
70 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
71 cursor_offset: usize,
72 current_buffer: &BufferSnapshot,
73) -> Vec<ScoredSnippet> {
74 let containing_range_identifier_occurrences =
75 IdentifierOccurrences::within_string(&excerpt_text.body);
76 let cursor_point = cursor_offset.to_point(¤t_buffer);
77
78 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
79 let end_point = Point::new(cursor_point.row + 1, 0);
80 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
81 ¤t_buffer
82 .text_for_range(start_point..end_point)
83 .collect::<String>(),
84 );
85
86 let mut snippets = identifier_to_references
87 .into_iter()
88 .flat_map(|(identifier, references)| {
89 let declarations =
90 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
91 let declaration_count = declarations.len();
92
93 declarations
94 .into_iter()
95 .filter_map(|(declaration_id, declaration)| match declaration {
96 Declaration::Buffer {
97 buffer_id,
98 declaration: buffer_declaration,
99 ..
100 } => {
101 let is_same_file = buffer_id == ¤t_buffer.remote_id();
102
103 if is_same_file {
104 let overlaps_excerpt =
105 range_intersection(&buffer_declaration.item_range, &excerpt.range)
106 .is_some();
107 if overlaps_excerpt
108 || excerpt
109 .parent_declarations
110 .iter()
111 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
112 {
113 None
114 } else {
115 let declaration_line = buffer_declaration
116 .item_range
117 .start
118 .to_point(current_buffer)
119 .row;
120 Some((
121 true,
122 (cursor_point.row as i32 - declaration_line as i32)
123 .unsigned_abs(),
124 declaration,
125 ))
126 }
127 } else {
128 Some((false, u32::MAX, declaration))
129 }
130 }
131 Declaration::File { .. } => {
132 // We can assume that a file declaration is in a different file,
133 // because the current one must be open
134 Some((false, u32::MAX, declaration))
135 }
136 })
137 .sorted_by_key(|&(_, distance, _)| distance)
138 .enumerate()
139 .map(
140 |(
141 declaration_line_distance_rank,
142 (is_same_file, declaration_line_distance, declaration),
143 )| {
144 let same_file_declaration_count = index.file_declaration_count(declaration);
145
146 score_snippet(
147 &identifier,
148 &references,
149 declaration.clone(),
150 is_same_file,
151 declaration_line_distance,
152 declaration_line_distance_rank,
153 same_file_declaration_count,
154 declaration_count,
155 &containing_range_identifier_occurrences,
156 &adjacent_identifier_occurrences,
157 cursor_point,
158 current_buffer,
159 )
160 },
161 )
162 .collect::<Vec<_>>()
163 })
164 .flatten()
165 .collect::<Vec<_>>();
166
167 snippets.sort_unstable_by_key(|snippet| {
168 OrderedFloat(
169 snippet
170 .score_density(SnippetStyle::Declaration)
171 .max(snippet.score_density(SnippetStyle::Signature)),
172 )
173 });
174
175 snippets
176}
177
/// Computes the overlap of two half-open ranges.
///
/// Returns `None` when the ranges share no elements, which includes ranges that merely touch
/// at an endpoint and ranges that are themselves empty.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lo = std::cmp::max(a.start.clone(), b.start.clone());
    let hi = std::cmp::min(a.end.clone(), b.end.clone());
    (lo < hi).then(|| lo..hi)
}
187
188fn score_snippet(
189 identifier: &Identifier,
190 references: &[Reference],
191 declaration: Declaration,
192 is_same_file: bool,
193 declaration_line_distance: u32,
194 declaration_line_distance_rank: usize,
195 same_file_declaration_count: usize,
196 declaration_count: usize,
197 containing_range_identifier_occurrences: &IdentifierOccurrences,
198 adjacent_identifier_occurrences: &IdentifierOccurrences,
199 cursor: Point,
200 current_buffer: &BufferSnapshot,
201) -> Option<ScoredSnippet> {
202 let is_referenced_nearby = references
203 .iter()
204 .any(|r| r.region == ReferenceRegion::Nearby);
205 let is_referenced_in_breadcrumb = references
206 .iter()
207 .any(|r| r.region == ReferenceRegion::Breadcrumb);
208 let reference_count = references.len();
209 let reference_line_distance = references
210 .iter()
211 .map(|r| {
212 let reference_line = r.range.start.to_point(current_buffer).row as i32;
213 (cursor.row as i32 - reference_line).unsigned_abs()
214 })
215 .min()
216 .unwrap();
217
218 let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
219 let item_signature_occurrences =
220 IdentifierOccurrences::within_string(&declaration.signature_text().0);
221 let containing_range_vs_item_jaccard = jaccard_similarity(
222 containing_range_identifier_occurrences,
223 &item_source_occurrences,
224 );
225 let containing_range_vs_signature_jaccard = jaccard_similarity(
226 containing_range_identifier_occurrences,
227 &item_signature_occurrences,
228 );
229 let adjacent_vs_item_jaccard =
230 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
231 let adjacent_vs_signature_jaccard =
232 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
233
234 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
235 containing_range_identifier_occurrences,
236 &item_source_occurrences,
237 );
238 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
239 containing_range_identifier_occurrences,
240 &item_signature_occurrences,
241 );
242 let adjacent_vs_item_weighted_overlap =
243 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
244 let adjacent_vs_signature_weighted_overlap =
245 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
246
247 let score_components = ScoreComponents {
248 is_same_file,
249 is_referenced_nearby,
250 is_referenced_in_breadcrumb,
251 reference_line_distance,
252 declaration_line_distance,
253 declaration_line_distance_rank,
254 reference_count,
255 same_file_declaration_count,
256 declaration_count,
257 containing_range_vs_item_jaccard,
258 containing_range_vs_signature_jaccard,
259 adjacent_vs_item_jaccard,
260 adjacent_vs_signature_jaccard,
261 containing_range_vs_item_weighted_overlap,
262 containing_range_vs_signature_weighted_overlap,
263 adjacent_vs_item_weighted_overlap,
264 adjacent_vs_signature_weighted_overlap,
265 };
266
267 Some(ScoredSnippet {
268 identifier: identifier.clone(),
269 declaration: declaration,
270 scores: Scores::score(&score_components),
271 score_components,
272 })
273}
274
/// The scores of one declaration, one value per [`SnippetStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score for the signature-only form of the declaration.
    pub signature: f32,
    /// Score for the full-declaration form.
    pub declaration: f32,
}
280
281impl Scores {
282 fn score(components: &ScoreComponents) -> Scores {
283 // Score related to how likely this is the correct declaration, range 0 to 1
284 let accuracy_score = if components.is_same_file {
285 // TODO: use declaration_line_distance_rank
286 1.0 / components.same_file_declaration_count as f32
287 } else {
288 1.0 / components.declaration_count as f32
289 };
290
291 // Score related to the distance between the reference and cursor, range 0 to 1
292 let distance_score = if components.is_referenced_nearby {
293 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
294 } else {
295 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
296 0.5
297 };
298
299 // For now instead of linear combination, the scores are just multiplied together.
300 let combined_score = 10.0 * accuracy_score * distance_score;
301
302 Scores {
303 signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
304 // declaration score gets boosted both by being multiplied by 2 and by there being more
305 // weighted overlap.
306 declaration: 2.0
307 * combined_score
308 * components.containing_range_vs_item_weighted_overlap,
309 }
310 }
311}