1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
2use collections::HashMap;
3use itertools::Itertools as _;
4use language::BufferSnapshot;
5use ordered_float::OrderedFloat;
6use serde::Serialize;
7use std::{cmp::Reverse, ops::Range};
8use strum::EnumIter;
9use text::{Point, ToPoint};
10
11use crate::{
12 Declaration, EditPredictionExcerpt, Identifier,
13 reference::{Reference, ReferenceRegion},
14 syntax_index::SyntaxIndexState,
15 text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
16};
17
/// Const-generic argument passed to `SyntaxIndexState::declarations_for_identifier` when
/// looking up candidate declarations for each referenced identifier.
/// NOTE(review): presumably an upper bound on candidates returned per identifier — confirm
/// against the index implementation.
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
19
/// A candidate declaration for a referenced identifier, together with the raw features and
/// the combined scores that `score_declaration` computed for it.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// The identifier whose references led to this candidate declaration.
    pub identifier: Identifier,
    /// The candidate declaration itself.
    pub declaration: Declaration,
    /// Raw features (locality, reference counts, text similarity) that `scores` is derived from.
    pub score_components: DeclarationScoreComponents,
    /// Combined scores computed from `score_components` by `DeclarationScores::score`.
    pub scores: DeclarationScores,
}
27
/// Which portion of a declaration a score or size refers to.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the declaration's signature range.
    Signature,
    /// The whole declaration item.
    Declaration,
}
33
34impl ScoredDeclaration {
35 /// Returns the score for this declaration with the specified style.
36 pub fn score(&self, style: DeclarationStyle) -> f32 {
37 match style {
38 DeclarationStyle::Signature => self.scores.signature,
39 DeclarationStyle::Declaration => self.scores.declaration,
40 }
41 }
42
43 pub fn size(&self, style: DeclarationStyle) -> usize {
44 match &self.declaration {
45 Declaration::File { declaration, .. } => match style {
46 DeclarationStyle::Signature => declaration.signature_range.len(),
47 DeclarationStyle::Declaration => declaration.text.len(),
48 },
49 Declaration::Buffer { declaration, .. } => match style {
50 DeclarationStyle::Signature => declaration.signature_range.len(),
51 DeclarationStyle::Declaration => declaration.item_range.len(),
52 },
53 }
54 }
55
56 pub fn score_density(&self, style: DeclarationStyle) -> f32 {
57 self.score(style) / (self.size(style)) as f32
58 }
59}
60
61pub fn scored_declarations(
62 index: &SyntaxIndexState,
63 excerpt: &EditPredictionExcerpt,
64 excerpt_occurrences: &Occurrences,
65 adjacent_occurrences: &Occurrences,
66 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
67 cursor_offset: usize,
68 current_buffer: &BufferSnapshot,
69) -> Vec<ScoredDeclaration> {
70 let cursor_point = cursor_offset.to_point(¤t_buffer);
71
72 let mut declarations = identifier_to_references
73 .into_iter()
74 .flat_map(|(identifier, references)| {
75 let declarations =
76 index.declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier);
77 let declaration_count = declarations.len();
78
79 declarations
80 .into_iter()
81 .filter_map(|(declaration_id, declaration)| match declaration {
82 Declaration::Buffer {
83 buffer_id,
84 declaration: buffer_declaration,
85 ..
86 } => {
87 let is_same_file = buffer_id == ¤t_buffer.remote_id();
88
89 if is_same_file {
90 let overlaps_excerpt =
91 range_intersection(&buffer_declaration.item_range, &excerpt.range)
92 .is_some();
93 if overlaps_excerpt
94 || excerpt
95 .parent_declarations
96 .iter()
97 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id)
98 {
99 None
100 } else {
101 let declaration_line = buffer_declaration
102 .item_range
103 .start
104 .to_point(current_buffer)
105 .row;
106 Some((
107 true,
108 (cursor_point.row as i32 - declaration_line as i32)
109 .unsigned_abs(),
110 declaration,
111 ))
112 }
113 } else {
114 Some((false, u32::MAX, declaration))
115 }
116 }
117 Declaration::File { .. } => {
118 // We can assume that a file declaration is in a different file,
119 // because the current one must be open
120 Some((false, u32::MAX, declaration))
121 }
122 })
123 .sorted_by_key(|&(_, distance, _)| distance)
124 .enumerate()
125 .map(
126 |(
127 declaration_line_distance_rank,
128 (is_same_file, declaration_line_distance, declaration),
129 )| {
130 let same_file_declaration_count = index.file_declaration_count(declaration);
131
132 score_declaration(
133 &identifier,
134 &references,
135 declaration.clone(),
136 is_same_file,
137 declaration_line_distance,
138 declaration_line_distance_rank,
139 same_file_declaration_count,
140 declaration_count,
141 &excerpt_occurrences,
142 &adjacent_occurrences,
143 cursor_point,
144 current_buffer,
145 )
146 },
147 )
148 .collect::<Vec<_>>()
149 })
150 .flatten()
151 .collect::<Vec<_>>();
152
153 declarations.sort_unstable_by_key(|declaration| {
154 let score_density = declaration
155 .score_density(DeclarationStyle::Declaration)
156 .max(declaration.score_density(DeclarationStyle::Signature));
157 Reverse(OrderedFloat(score_density))
158 });
159
160 declarations
161}
162
/// Returns the overlap of two half-open ranges, or `None` when they share no element.
///
/// Ranges that merely touch (e.g. `0..3` and `3..5`) or that are empty do not intersect.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let lower = std::cmp::max(a.start.clone(), b.start.clone());
    let upper = std::cmp::min(a.end.clone(), b.end.clone());
    (lower < upper).then(|| lower..upper)
}
172
173fn score_declaration(
174 identifier: &Identifier,
175 references: &[Reference],
176 declaration: Declaration,
177 is_same_file: bool,
178 declaration_line_distance: u32,
179 declaration_line_distance_rank: usize,
180 same_file_declaration_count: usize,
181 declaration_count: usize,
182 excerpt_occurrences: &Occurrences,
183 adjacent_occurrences: &Occurrences,
184 cursor: Point,
185 current_buffer: &BufferSnapshot,
186) -> Option<ScoredDeclaration> {
187 let is_referenced_nearby = references
188 .iter()
189 .any(|r| r.region == ReferenceRegion::Nearby);
190 let is_referenced_in_breadcrumb = references
191 .iter()
192 .any(|r| r.region == ReferenceRegion::Breadcrumb);
193 let reference_count = references.len();
194 let reference_line_distance = references
195 .iter()
196 .map(|r| {
197 let reference_line = r.range.start.to_point(current_buffer).row as i32;
198 (cursor.row as i32 - reference_line).unsigned_abs()
199 })
200 .min()
201 .unwrap();
202
203 let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
204 let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
205 let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
206 let excerpt_vs_signature_jaccard =
207 jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
208 let adjacent_vs_item_jaccard =
209 jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
210 let adjacent_vs_signature_jaccard =
211 jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
212
213 let excerpt_vs_item_weighted_overlap =
214 weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
215 let excerpt_vs_signature_weighted_overlap =
216 weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
217 let adjacent_vs_item_weighted_overlap =
218 weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
219 let adjacent_vs_signature_weighted_overlap =
220 weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
221
222 // TODO: Consider adding declaration_file_count
223 let score_components = DeclarationScoreComponents {
224 is_same_file,
225 is_referenced_nearby,
226 is_referenced_in_breadcrumb,
227 reference_line_distance,
228 declaration_line_distance,
229 declaration_line_distance_rank,
230 reference_count,
231 same_file_declaration_count,
232 declaration_count,
233 excerpt_vs_item_jaccard,
234 excerpt_vs_signature_jaccard,
235 adjacent_vs_item_jaccard,
236 adjacent_vs_signature_jaccard,
237 excerpt_vs_item_weighted_overlap,
238 excerpt_vs_signature_weighted_overlap,
239 adjacent_vs_item_weighted_overlap,
240 adjacent_vs_signature_weighted_overlap,
241 };
242
243 Some(ScoredDeclaration {
244 identifier: identifier.clone(),
245 declaration: declaration,
246 scores: DeclarationScores::score(&score_components),
247 score_components,
248 })
249}
250
/// Scores derived from a declaration's `DeclarationScoreComponents` by `DeclarationScores::score`.
#[derive(Clone, Debug, Serialize)]
pub struct DeclarationScores {
    /// Score used for `DeclarationStyle::Signature`.
    pub signature: f32,
    /// Score used for `DeclarationStyle::Declaration`.
    pub declaration: f32,
    /// Estimated likelihood that this is the correct declaration: the reciprocal of the
    /// same-file declaration count (same-file candidates) or of the candidate count (others).
    pub retrieval: f32,
}
257
258impl DeclarationScores {
259 fn score(components: &DeclarationScoreComponents) -> DeclarationScores {
260 // TODO: handle truncation
261
262 // Score related to how likely this is the correct declaration, range 0 to 1
263 let retrieval = if components.is_same_file {
264 // TODO: use declaration_line_distance_rank
265 1.0 / components.same_file_declaration_count as f32
266 } else {
267 1.0 / components.declaration_count as f32
268 };
269
270 // Score related to the distance between the reference and cursor, range 0 to 1
271 let distance_score = if components.is_referenced_nearby {
272 1.0 / (1.0 + components.reference_line_distance as f32 / 10.0).powf(2.0)
273 } else {
274 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
275 0.5
276 };
277
278 // For now instead of linear combination, the scores are just multiplied together.
279 let combined_score = 10.0 * retrieval * distance_score;
280
281 DeclarationScores {
282 signature: combined_score * components.excerpt_vs_signature_weighted_overlap,
283 // declaration score gets boosted both by being multiplied by 2 and by there being more
284 // weighted overlap.
285 declaration: 2.0 * combined_score * components.excerpt_vs_item_weighted_overlap,
286 retrieval,
287 }
288 }
289}