1use cloud_llm_client::predict_edits_v3::ScoreComponents;
2use itertools::Itertools as _;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use serde::Serialize;
6use std::{collections::HashMap, ops::Range};
7use strum::EnumIter;
8use text::{Point, ToPoint};
9
10use crate::{
11 Declaration, EditPredictionExcerpt, EditPredictionExcerptText, Identifier,
12 reference::{Reference, ReferenceRegion},
13 syntax_index::SyntaxIndexState,
14 text_similarity::{IdentifierOccurrences, jaccard_similarity, weighted_overlap_coefficient},
15};
16
/// Const-generic limit passed to `SyntaxIndexState::declarations_for_identifier`
/// in `scored_snippets`; presumably caps how many candidate declarations are
/// considered per identifier (TODO confirm against the index implementation).
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
18
/// A candidate declaration for a referenced identifier, together with the raw
/// features it was scored on and the scores derived from them.
#[derive(Clone, Debug)]
pub struct ScoredSnippet {
    /// The identifier whose references led to this declaration being considered.
    pub identifier: Identifier,
    /// The candidate declaration that was scored.
    pub declaration: Declaration,
    /// Raw features computed by `score_snippet`.
    pub score_components: ScoreComponents,
    /// Per-style scores derived from `score_components` by `Scores::score`.
    pub scores: Scores,
}
26
/// How much of a declaration's text a snippet covers; selects which score and
/// which size `ScoredSnippet` reports.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum SnippetStyle {
    /// Only the declaration's signature text.
    Signature,
    /// The declaration's full item text.
    Declaration,
}
32
33impl ScoredSnippet {
34 /// Returns the score for this snippet with the specified style.
35 pub fn score(&self, style: SnippetStyle) -> f32 {
36 match style {
37 SnippetStyle::Signature => self.scores.signature,
38 SnippetStyle::Declaration => self.scores.declaration,
39 }
40 }
41
42 pub fn size(&self, style: SnippetStyle) -> usize {
43 // TODO: how to handle truncation?
44 match &self.declaration {
45 Declaration::File { declaration, .. } => match style {
46 SnippetStyle::Signature => declaration.signature_range_in_text.len(),
47 SnippetStyle::Declaration => declaration.text.len(),
48 },
49 Declaration::Buffer { declaration, .. } => match style {
50 SnippetStyle::Signature => declaration.signature_range.len(),
51 SnippetStyle::Declaration => declaration.item_range.len(),
52 },
53 }
54 }
55
56 pub fn score_density(&self, style: SnippetStyle) -> f32 {
57 self.score(style) / (self.size(style)) as f32
58 }
59}
60
61pub fn scored_snippets(
62 index: &SyntaxIndexState,
63 excerpt: &EditPredictionExcerpt,
64 excerpt_text: &EditPredictionExcerptText,
65 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
66 cursor_offset: usize,
67 current_buffer: &BufferSnapshot,
68) -> Vec<ScoredSnippet> {
69 let containing_range_identifier_occurrences =
70 IdentifierOccurrences::within_string(&excerpt_text.body);
71 let cursor_point = cursor_offset.to_point(¤t_buffer);
72
73 let start_point = Point::new(cursor_point.row.saturating_sub(2), 0);
74 let end_point = Point::new(cursor_point.row + 1, 0);
75 let adjacent_identifier_occurrences = IdentifierOccurrences::within_string(
76 ¤t_buffer
77 .text_for_range(start_point..end_point)
78 .collect::<String>(),
79 );
80
81 let mut snippets = identifier_to_references
82 .into_iter()
83 .flat_map(|(identifier, references)| {
84 let declarations =
85 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
86 let declaration_count = declarations.len();
87
88 declarations
89 .into_iter()
90 .filter_map(|(declaration_id, declaration)| match declaration {
91 Declaration::Buffer {
92 buffer_id,
93 declaration: buffer_declaration,
94 ..
95 } => {
96 let is_same_file = buffer_id == ¤t_buffer.remote_id();
97
98 if is_same_file {
99 let overlaps_excerpt =
100 range_intersection(&buffer_declaration.item_range, &excerpt.range)
101 .is_some();
102 if overlaps_excerpt
103 || excerpt
104 .parent_declarations
105 .iter()
106 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
107 {
108 None
109 } else {
110 let declaration_line = buffer_declaration
111 .item_range
112 .start
113 .to_point(current_buffer)
114 .row;
115 Some((
116 true,
117 (cursor_point.row as i32 - declaration_line as i32)
118 .unsigned_abs(),
119 declaration,
120 ))
121 }
122 } else {
123 Some((false, u32::MAX, declaration))
124 }
125 }
126 Declaration::File { .. } => {
127 // We can assume that a file declaration is in a different file,
128 // because the current one must be open
129 Some((false, u32::MAX, declaration))
130 }
131 })
132 .sorted_by_key(|&(_, distance, _)| distance)
133 .enumerate()
134 .map(
135 |(
136 declaration_line_distance_rank,
137 (is_same_file, declaration_line_distance, declaration),
138 )| {
139 let same_file_declaration_count = index.file_declaration_count(declaration);
140
141 score_snippet(
142 &identifier,
143 &references,
144 declaration.clone(),
145 is_same_file,
146 declaration_line_distance,
147 declaration_line_distance_rank,
148 same_file_declaration_count,
149 declaration_count,
150 &containing_range_identifier_occurrences,
151 &adjacent_identifier_occurrences,
152 cursor_point,
153 current_buffer,
154 )
155 },
156 )
157 .collect::<Vec<_>>()
158 })
159 .flatten()
160 .collect::<Vec<_>>();
161
162 snippets.sort_unstable_by_key(|snippet| {
163 OrderedFloat(
164 snippet
165 .score_density(SnippetStyle::Declaration)
166 .max(snippet.score_density(SnippetStyle::Signature)),
167 )
168 });
169
170 snippets
171}
172
/// Returns the overlap of two half-open ranges, or `None` when they share no
/// element. Ranges that merely touch, such as `0..3` and `3..5`, do not overlap,
/// and an empty or inverted input range never produces an intersection.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then(|| lower..upper)
}
182
183fn score_snippet(
184 identifier: &Identifier,
185 references: &[Reference],
186 declaration: Declaration,
187 is_same_file: bool,
188 declaration_line_distance: u32,
189 declaration_line_distance_rank: usize,
190 same_file_declaration_count: usize,
191 declaration_count: usize,
192 containing_range_identifier_occurrences: &IdentifierOccurrences,
193 adjacent_identifier_occurrences: &IdentifierOccurrences,
194 cursor: Point,
195 current_buffer: &BufferSnapshot,
196) -> Option<ScoredSnippet> {
197 let is_referenced_nearby = references
198 .iter()
199 .any(|r| r.region == ReferenceRegion::Nearby);
200 let is_referenced_in_breadcrumb = references
201 .iter()
202 .any(|r| r.region == ReferenceRegion::Breadcrumb);
203 let reference_count = references.len();
204 let reference_line_distance = references
205 .iter()
206 .map(|r| {
207 let reference_line = r.range.start.to_point(current_buffer).row as i32;
208 (cursor.row as i32 - reference_line).unsigned_abs()
209 })
210 .min()
211 .unwrap();
212
213 let item_source_occurrences = IdentifierOccurrences::within_string(&declaration.item_text().0);
214 let item_signature_occurrences =
215 IdentifierOccurrences::within_string(&declaration.signature_text().0);
216 let containing_range_vs_item_jaccard = jaccard_similarity(
217 containing_range_identifier_occurrences,
218 &item_source_occurrences,
219 );
220 let containing_range_vs_signature_jaccard = jaccard_similarity(
221 containing_range_identifier_occurrences,
222 &item_signature_occurrences,
223 );
224 let adjacent_vs_item_jaccard =
225 jaccard_similarity(adjacent_identifier_occurrences, &item_source_occurrences);
226 let adjacent_vs_signature_jaccard =
227 jaccard_similarity(adjacent_identifier_occurrences, &item_signature_occurrences);
228
229 let containing_range_vs_item_weighted_overlap = weighted_overlap_coefficient(
230 containing_range_identifier_occurrences,
231 &item_source_occurrences,
232 );
233 let containing_range_vs_signature_weighted_overlap = weighted_overlap_coefficient(
234 containing_range_identifier_occurrences,
235 &item_signature_occurrences,
236 );
237 let adjacent_vs_item_weighted_overlap =
238 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_source_occurrences);
239 let adjacent_vs_signature_weighted_overlap =
240 weighted_overlap_coefficient(adjacent_identifier_occurrences, &item_signature_occurrences);
241
242 // TODO: Consider adding declaration_file_count
243 let score_components = ScoreComponents {
244 is_same_file,
245 is_referenced_nearby,
246 is_referenced_in_breadcrumb,
247 reference_line_distance,
248 declaration_line_distance,
249 declaration_line_distance_rank,
250 reference_count,
251 same_file_declaration_count,
252 declaration_count,
253 containing_range_vs_item_jaccard,
254 containing_range_vs_signature_jaccard,
255 adjacent_vs_item_jaccard,
256 adjacent_vs_signature_jaccard,
257 containing_range_vs_item_weighted_overlap,
258 containing_range_vs_signature_weighted_overlap,
259 adjacent_vs_item_weighted_overlap,
260 adjacent_vs_signature_weighted_overlap,
261 };
262
263 Some(ScoredSnippet {
264 identifier: identifier.clone(),
265 declaration: declaration,
266 scores: Scores::score(&score_components),
267 score_components,
268 })
269}
270
/// Final scores for a snippet, one per [`SnippetStyle`]; produced by `Scores::score`.
#[derive(Clone, Debug, Serialize)]
pub struct Scores {
    /// Score when the snippet is rendered as just its signature.
    pub signature: f32,
    /// Score when the snippet is rendered as the full declaration.
    pub declaration: f32,
}
276
277impl Scores {
278 fn score(components: &ScoreComponents) -> Scores {
279 // Score related to how likely this is the correct declaration, range 0 to 1
280 let accuracy_score = if components.is_same_file {
281 // TODO: use declaration_line_distance_rank
282 1.0 / components.same_file_declaration_count as f32
283 } else {
284 1.0 / components.declaration_count as f32
285 };
286
287 // Score related to the distance between the reference and cursor, range 0 to 1
288 let distance_score = if components.is_referenced_nearby {
289 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
290 } else {
291 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
292 0.5
293 };
294
295 // For now instead of linear combination, the scores are just multiplied together.
296 let combined_score = 10.0 * accuracy_score * distance_score;
297
298 Scores {
299 signature: combined_score * components.containing_range_vs_signature_weighted_overlap,
300 // declaration score gets boosted both by being multiplied by 2 and by there being more
301 // weighted overlap.
302 declaration: 2.0
303 * combined_score
304 * components.containing_range_vs_item_weighted_overlap,
305 }
306 }
307}