1use cloud_llm_client::predict_edits_v3::ScoreComponents;
2use itertools::Itertools as _;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use serde::Serialize;
6use std::{cmp::Reverse, collections::HashMap, ops::Range};
7use strum::EnumIter;
8use text::{Point, ToPoint};
9
10use crate::{
11 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
12 reference::{Reference, ReferenceRegion},
13 syntax_index::SyntaxIndexState,
14 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
15};
16
/// Upper bound on how many declarations are fetched from the syntax index for a single
/// identifier (passed as the const-generic limit to `declarations_for_identifier`).
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
18
/// A candidate declaration for an identifier referenced around the cursor, together with
/// the raw ranking features and the per-style scores derived from them.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The identifier whose references led to this declaration.
    pub identifier: Identifier,
    /// The candidate declaration itself (from an open buffer or from a file).
    pub declaration: Declaration,
    /// Raw features (reference/declaration distances, counts, text similarities)
    /// computed by `score_snippet`.
    pub score_components: ScoreComponents,
    /// Signature and full-declaration scores derived from `score_components`.
    pub scores: Scores,
}
26
/// How much of a declaration a snippet would include.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// Only the declaration's signature range.
    Signature,
    /// The full declaration text.
    Declaration,
}
32
33impl ScoredSnippet {
34 /// Returns the score for this snippet with the specified style.
35 pub fn score(&self, style: SnippetStyle) -> f32 {
36 match style {
37 SnippetStyle::Signature => self.scores.signature,
38 SnippetStyle::Declaration => self.scores.declaration,
39 }
40 }
41
42 pub fn size(&self, style: SnippetStyle) -> usize {
43 match &self.declaration {
44 Declaration::File { declaration, .. } => match style {
45 SnippetStyle::Signature => declaration.signature_range.len(),
46 SnippetStyle::Declaration => declaration.text.len(),
47 },
48 Declaration::Buffer { declaration, .. } => match style {
49 SnippetStyle::Signature => declaration.signature_range.len(),
50 SnippetStyle::Declaration => declaration.item_range.len(),
51 },
52 }
53 }
54
55 pub fn score_density(&self, style: SnippetStyle) -> f32 {
56 self.score(style) / (self.size(style)) as f32
57 }
58}
59
60pub fn scored_snippets(
61 index: &SyntaxIndexState,
62 excerpt: &EditPredictionExcerpt,
63 excerpt_text: &EditPredictionExcerptText,
64 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
65 cursor_offset: usize,
66 current_buffer: &BufferSnapshot,
67) -> Vec<ScoredSnippet> {
68 let containing_range_identifier_occurrences =
69 IdentifierOccurrences::within_string(&excerpt_text.body);
70 let cursor_point = cursor_offset.to_point(¤t_buffer);
71
72 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
73 let end_point = Point::new(cursor_point.row + 1, 0);
74 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
75 ¤t_buffer
76 .text_for_range(start_point..end_point)
77 .collect::<String>(),
78 );
79
80 let mut snippets = identifier_to_references
81 .into_iter()
82 .flat_map(|(identifier, references)| {
83 let declarations =
84 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
85 let declaration_count = declarations.len();
86
87 declarations
88 .into_iter()
89 .filter_map(|(declaration_id, declaration)| match declaration {
90 Declaration::Buffer {
91 buffer_id,
92 declaration: buffer_declaration,
93 ..
94 } => {
95 let is_same_file = buffer_id == ¤t_buffer.remote_id();
96
97 if is_same_file {
98 let overlaps_excerpt =
99 range_intersection(&buffer_declaration.item_range, &excerpt.range)
100 .is_some();
101 if overlaps_excerpt
102 || excerpt
103 .parent_declarations
104 .iter()
105 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
106 {
107 None
108 } else {
109 let declaration_line = buffer_declaration
110 .item_range
111 .start
112 .to_point(current_buffer)
113 .row;
114 Some((
115 true,
116 (cursor_point.row as i32 - declaration_line as i32)
117 .unsigned_abs(),
118 declaration,
119 ))
120 }
121 } else {
122 Some((false, u32::MAX, declaration))
123 }
124 }
125 Declaration::File { .. } => {
126 // We can assume that a file declaration is in a different file,
127 // because the current one must be open
128 Some((false, u32::MAX, declaration))
129 }
130 })
131 .sorted_by_key(|&(_, distance, _)| distance)
132 .enumerate()
133 .map(
134 |(
135 declaration_line_distance_rank,
136 (is_same_file, declaration_line_distance, declaration),
137 )| {
138 let same_file_declaration_count = index.file_declaration_count(declaration);
139
140 score_snippet(
141 &identifier,
142 &references,
143 declaration.clone(),
144 is_same_file,
145 declaration_line_distance,
146 declaration_line_distance_rank,
147 same_file_declaration_count,
148 declaration_count,
149 &containing_range_identifier_occurrences,
150 &adjacent_identifier_occurrences,
151 cursor_point,
152 current_buffer,
153 )
154 },
155 )
156 .collect::<Vec<_>>()
157 })
158 .flatten()
159 .collect::<Vec<_>>();
160
161 snippets.sort_unstable_by_key(|snippet| {
162 let score_density = snippet
163 .score_density(SnippetStyle::Declaration)
164 .max(snippet.score_density(SnippetStyle::Signature));
165 Reverse(OrderedFloat(score_density))
166 });
167
168 snippets
169}
170
/// Returns the overlap of two half-open ranges, or `None` when they share no element.
///
/// Ranges that merely touch (e.g. `0..3` and `3..5`) or empty ranges yield `None`.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then_some(lower..upper)
}
180
181fn score_snippet(
182 identifier: &Identifier,
183 references: &[Reference],
184 declaration: Declaration,
185 is_same_file: bool,
186 declaration_line_distance: u32,
187 declaration_line_distance_rank: usize,
188 same_file_declaration_count: usize,
189 declaration_count: usize,
190 containing_range_identifier_occurrences: &IdentifierOccurrences,
191 adjacent_identifier_occurrences: &IdentifierOccurrences,
192 cursor: Point,
193 current_buffer: &BufferSnapshot,
194) -> Option<ScoredSnippet> {
195 let is_referenced_nearby = references
196 .iter()
197 .any(|r| r.region == ReferenceRegion::Nearby);
198 let is_referenced_in_breadcrumb = references
199 .iter()
200 .any(|r| r.region == ReferenceRegion::Breadcrumb);
201 let reference_count = references.len();
202 let reference_line_distance = references
203 .iter()
204 .map(|r| {
205 let reference_line = r.range.start.to_point(current_buffer).row as i32;
206 (cursor.row as i32 - reference_line).unsigned_abs()
207 })
208 .min()
209 .unwrap();
210
211 let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
212 let item_signature_occurrences =
213 IdentifierOccurrences::within_string(&declaration.signature_text().0);
214 let containing_range_vs_item_jaccard = jaccard_similarity(
215 containing_range_identifier_occurrences,
216 &item_source_occurrences,
217 );
218 let containing_range_vs_signature_jaccard = jaccard_similarity(
219 containing_range_identifier_occurrences,
220 &item_signature_occurrences,
221 );
222 let adjacent_vs_item_jaccard =
223 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
224 let adjacent_vs_signature_jaccard =
225 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
226
227 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
228 containing_range_identifier_occurrences,
229 &item_source_occurrences,
230 );
231 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
232 containing_range_identifier_occurrences,
233 &item_signature_occurrences,
234 );
235 let adjacent_vs_item_weighted_overlap =
236 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
237 let adjacent_vs_signature_weighted_overlap =
238 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
239
240 // TODO: Consider adding declaration_file_count
241 let score_components = ScoreComponents {
242 is_same_file,
243 is_referenced_nearby,
244 is_referenced_in_breadcrumb,
245 reference_line_distance,
246 declaration_line_distance,
247 declaration_line_distance_rank,
248 reference_count,
249 same_file_declaration_count,
250 declaration_count,
251 containing_range_vs_item_jaccard,
252 containing_range_vs_signature_jaccard,
253 adjacent_vs_item_jaccard,
254 adjacent_vs_signature_jaccard,
255 containing_range_vs_item_weighted_overlap,
256 containing_range_vs_signature_weighted_overlap,
257 adjacent_vs_item_weighted_overlap,
258 adjacent_vs_signature_weighted_overlap,
259 };
260
261 Some(ScoredSnippet {
262 identifier: identifier.clone(),
263 declaration: declaration,
264 scores: Scores::score(&score_components),
265 score_components,
266 })
267}
268
/// Ranking scores for a snippet, one per [`SnippetStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score when only the declaration's signature is included.
    pub signature: f32,
    /// Score when the full declaration is included.
    pub declaration: f32,
}
274
275impl Scores {
276 fn score(components: &ScoreComponents) -> Scores {
277 // TODO: handle truncation
278
279 // Score related to how likely this is the correct declaration, range 0 to 1
280 let accuracy_score = if components.is_same_file {
281 // TODO: use declaration_line_distance_rank
282 1.0 / components.same_file_declaration_count as f32
283 } else {
284 1.0 / components.declaration_count as f32
285 };
286
287 // Score related to the distance between the reference and cursor, range 0 to 1
288 let distance_score = if components.is_referenced_nearby {
289 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
290 } else {
291 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
292 0.5
293 };
294
295 // For now instead of linear combination, the scores are just multiplied together.
296 let combined_score = 10.0 * accuracy_score * distance_score;
297
298 Scores {
299 signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
300 // declaration score gets boosted both by being multiplied by 2 and by there being more
301 // weighted overlap.
302 declaration: 2.0
303 * combined_score
304 * components.containing_range_vs_item_weighted_overlap,
305 }
306 }
307}