1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
2use itertools::Itertools as _;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use serde::Serialize;
6use std::{cmp::Reverse, collections::HashMap, ops::Range};
7use strum::EnumIter;
8use text::{Point, ToPoint};
9
10use crate::{
11 Declaration, EditPredictionExcerpt, Identifier,
12 reference::{Reference, ReferenceRegion},
13 syntax_index::SyntaxIndexState,
14 text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
15};
16
/// Const-generic argument passed to `SyntaxIndexState::declarations_for_identifier`
/// when looking up the declarations of an identifier.
// NOTE(review): presumably this caps how many declarations are returned per
// identifier — confirm against the syntax index implementation.
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
18
/// A candidate declaration for a referenced identifier, together with the raw
/// components that were measured for it and the scores derived from them.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// The identifier whose references led to this declaration being considered.
    pub identifier: Identifier,
    /// The declaration that was scored.
    pub declaration: Declaration,
    /// The individual measurements the scores were computed from.
    pub score_components: DeclarationScoreComponents,
    /// The final scores, one per [`DeclarationStyle`].
    pub scores: DeclarationScores,
}
26
/// Selects which part of a declaration a score or size refers to.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the declaration's signature range.
    Signature,
    /// The whole declaration (its text or item range).
    Declaration,
}
32
33impl ScoredDeclaration {
34 /// Returns the score for this declaration with the specified style.
35 pub fn score(&self, style: DeclarationStyle) -> f32 {
36 match style {
37 DeclarationStyle::Signature => self.scores.signature,
38 DeclarationStyle::Declaration => self.scores.declaration,
39 }
40 }
41
42 pub fn size(&self, style: DeclarationStyle) -> usize {
43 match &self.declaration {
44 Declaration::File { declaration, .. } => match style {
45 DeclarationStyle::Signature => declaration.signature_range.len(),
46 DeclarationStyle::Declaration => declaration.text.len(),
47 },
48 Declaration::Buffer { declaration, .. } => match style {
49 DeclarationStyle::Signature => declaration.signature_range.len(),
50 DeclarationStyle::Declaration => declaration.item_range.len(),
51 },
52 }
53 }
54
55 pub fn score_density(&self, style: DeclarationStyle) -> f32 {
56 self.score(style) / (self.size(style)) as f32
57 }
58}
59
60pub fn scored_declarations(
61 index: &SyntaxIndexState,
62 excerpt: &EditPredictionExcerpt,
63 excerpt_occurrences: &Occurrences,
64 adjacent_occurrences: &Occurrences,
65 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
66 cursor_offset: usize,
67 current_buffer: &BufferSnapshot,
68) -> Vec<ScoredDeclaration> {
69 let cursor_point = cursor_offset.to_point(¤t_buffer);
70
71 let mut declarations = identifier_to_references
72 .into_iter()
73 .flat_map(|(identifier, references)| {
74 let declarations =
75 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
76 let declaration_count = declarations.len();
77
78 declarations
79 .into_iter()
80 .filter_map(|(declaration_id, declaration)| match declaration {
81 Declaration::Buffer {
82 buffer_id,
83 declaration: buffer_declaration,
84 ..
85 } => {
86 let is_same_file = buffer_id == ¤t_buffer.remote_id();
87
88 if is_same_file {
89 let overlaps_excerpt =
90 range_intersection(&buffer_declaration.item_range, &excerpt.range)
91 .is_some();
92 if overlaps_excerpt
93 || excerpt
94 .parent_declarations
95 .iter()
96 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
97 {
98 None
99 } else {
100 let declaration_line = buffer_declaration
101 .item_range
102 .start
103 .to_point(current_buffer)
104 .row;
105 Some((
106 true,
107 (cursor_point.row as i32 - declaration_line as i32)
108 .unsigned_abs(),
109 declaration,
110 ))
111 }
112 } else {
113 Some((false, u32::MAX, declaration))
114 }
115 }
116 Declaration::File { .. } => {
117 // We can assume that a file declaration is in a different file,
118 // because the current one must be open
119 Some((false, u32::MAX, declaration))
120 }
121 })
122 .sorted_by_key(|&(_, distance, _)| distance)
123 .enumerate()
124 .map(
125 |(
126 declaration_line_distance_rank,
127 (is_same_file, declaration_line_distance, declaration),
128 )| {
129 let same_file_declaration_count = index.file_declaration_count(declaration);
130
131 score_declaration(
132 &identifier,
133 &references,
134 declaration.clone(),
135 is_same_file,
136 declaration_line_distance,
137 declaration_line_distance_rank,
138 same_file_declaration_count,
139 declaration_count,
140 &excerpt_occurrences,
141 &adjacent_occurrences,
142 cursor_point,
143 current_buffer,
144 )
145 },
146 )
147 .collect::<Vec<_>>()
148 })
149 .flatten()
150 .collect::<Vec<_>>();
151
152 declarations.sort_unstable_by_key(|declaration| {
153 let score_density = declaration
154 .score_density(DeclarationStyle::Declaration)
155 .max(declaration.score_density(DeclarationStyle::Signature));
156 Reverse(OrderedFloat(score_density))
157 });
158
159 declarations
160}
161
/// Returns the overlapping part of `a` and `b`, or `None` when the two ranges
/// share no elements (including when they merely touch or either is empty).
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(&a.start, &b.start);
    let upper = std::cmp::min(&a.end, &b.end);
    (lower < upper).then(|| lower.clone()..upper.clone())
}
171
172fn score_declaration(
173 identifier: &Identifier,
174 references: &[Reference],
175 declaration: Declaration,
176 is_same_file: bool,
177 declaration_line_distance: u32,
178 declaration_line_distance_rank: usize,
179 same_file_declaration_count: usize,
180 declaration_count: usize,
181 excerpt_occurrences: &Occurrences,
182 adjacent_occurrences: &Occurrences,
183 cursor: Point,
184 current_buffer: &BufferSnapshot,
185) -> Option<ScoredDeclaration> {
186 let is_referenced_nearby = references
187 .iter()
188 .any(|r| r.region == ReferenceRegion::Nearby);
189 let is_referenced_in_breadcrumb = references
190 .iter()
191 .any(|r| r.region == ReferenceRegion::Breadcrumb);
192 let reference_count = references.len();
193 let reference_line_distance = references
194 .iter()
195 .map(|r| {
196 let reference_line = r.range.start.to_point(current_buffer).row as i32;
197 (cursor.row as i32 - reference_line).unsigned_abs()
198 })
199 .min()
200 .unwrap();
201
202 let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
203 let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
204 let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
205 let excerpt_vs_signature_jaccard =
206 jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
207 let adjacent_vs_item_jaccard =
208 jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
209 let adjacent_vs_signature_jaccard =
210 jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
211
212 let excerpt_vs_item_weighted_overlap =
213 weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
214 let excerpt_vs_signature_weighted_overlap =
215 weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
216 let adjacent_vs_item_weighted_overlap =
217 weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
218 let adjacent_vs_signature_weighted_overlap =
219 weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
220
221 // TODO: Consider adding declaration_file_count
222 let score_components = DeclarationScoreComponents {
223 is_same_file,
224 is_referenced_nearby,
225 is_referenced_in_breadcrumb,
226 reference_line_distance,
227 declaration_line_distance,
228 declaration_line_distance_rank,
229 reference_count,
230 same_file_declaration_count,
231 declaration_count,
232 excerpt_vs_item_jaccard,
233 excerpt_vs_signature_jaccard,
234 adjacent_vs_item_jaccard,
235 adjacent_vs_signature_jaccard,
236 excerpt_vs_item_weighted_overlap,
237 excerpt_vs_signature_weighted_overlap,
238 adjacent_vs_item_weighted_overlap,
239 adjacent_vs_signature_weighted_overlap,
240 };
241
242 Some(ScoredDeclaration {
243 identifier: identifier.clone(),
244 declaration: declaration,
245 scores: DeclarationScores::score(&score_components),
246 score_components,
247 })
248}
249
/// The final scores of a declaration, one per [`DeclarationStyle`].
#[derive(Clone, Debug, Serialize)]
pub struct DeclarationScores {
    /// Score for the declaration's signature alone.
    pub signature: f32,
    /// Score for the whole declaration.
    pub declaration: f32,
}
255
256impl DeclarationScores {
257 fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
258 // TODO: handle truncation
259
260 // Score related to how likely this is the correct declaration, range 0 to 1
261 let accuracy_score = if components.is_same_file {
262 // TODO: use declaration_line_distance_rank
263 1.0 / components.same_file_declaration_count as f32
264 } else {
265 1.0 / components.declaration_count as f32
266 };
267
268 // Score related to the distance between the reference and cursor, range 0 to 1
269 let distance_score = if components.is_referenced_nearby {
270 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
271 } else {
272 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
273 0.5
274 };
275
276 // For now instead of linear combination, the scores are just multiplied together.
277 let combined_score = 10.0 * accuracy_score * distance_score;
278
279 DeclarationScores {
280 signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
281 // declaration score gets boosted both by being multiplied by 2 and by there being more
282 // weighted overlap.
283 declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
284 }
285 }
286}