1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
2use collections::HashMap;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use serde::Serialize;
6use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
7use strum::EnumIter;
8use text::{Point, ToPoint};
9
10use crate::{
11 CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
12 imports::{Import, Imports, Module},
13 reference::{Reference, ReferenceRegion},
14 syntax_index::SyntaxIndexState,
15 text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
16};
17
/// Cap on the number of candidate declarations requested from the syntax index for each
/// referenced identifier; passed as the const generic of
/// `SyntaxIndexState::declarations_for_identifier` in [`scored_declarations`].
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
19
/// Options controlling which candidate declarations [`scored_declarations`] keeps.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EditPredictionScoreOptions {
    /// When `true`, declarations in the current buffer are skipped if their item range
    /// intersects the excerpt range, or if they are one of the excerpt's parent declarations,
    /// since they are treated as already included in the prompt.
    pub omit_excerpt_overlaps: bool,
}
24
/// A candidate declaration for a referenced identifier, together with the raw components used
/// to compute its scores.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// identifier used by the local reference (for a sole aliased import this is the local
    /// alias, not the external name that was used to look the declaration up)
    pub identifier: Identifier,
    /// The candidate declaration found in the syntax index.
    pub declaration: Declaration,
    /// Inputs to [`ScoredDeclaration::score`] and [`ScoredDeclaration::retrieval_score`].
    pub components: DeclarationScoreComponents,
}
32
/// How much of a declaration would be included, which determines its size and its score.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the declaration's signature (its `signature_range`).
    Signature,
    /// The whole declaration item (`text` for file declarations, `item_range` for buffer
    /// declarations).
    Declaration,
}
38
/// Serializable set of scores for a single declaration.
///
/// NOTE(review): not constructed in this file; presumably `signature` and `declaration` hold
/// [`ScoredDeclaration::score`] for the matching [`DeclarationStyle`], and `retrieval` holds
/// [`ScoredDeclaration::retrieval_score`]. Confirm at the construction site.
#[derive(Clone, Debug, Serialize, Default)]
pub struct DeclarationScores {
    pub signature: f32,
    pub declaration: f32,
    pub retrieval: f32,
}
45
46impl ScoredDeclaration {
47 /// Returns the score for this declaration with the specified style.
48 pub fn score(&self, style: DeclarationStyle) -> f32 {
49 // TODO: handle truncation
50
51 // Score related to how likely this is the correct declaration, range 0 to 1
52 let retrieval = self.retrieval_score();
53
54 // Score related to the distance between the reference and cursor, range 0 to 1
55 let distance_score = if self.components.is_referenced_nearby {
56 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
57 } else {
58 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
59 0.5
60 };
61
62 // For now instead of linear combination, the scores are just multiplied together.
63 let combined_score = 10.0 * retrieval * distance_score;
64
65 match style {
66 DeclarationStyle::Signature => {
67 combined_score * self.components.excerpt_vs_signature_weighted_overlap
68 }
69 DeclarationStyle::Declaration => {
70 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
71 }
72 }
73 }
74
75 pub fn retrieval_score(&self) -> f32 {
76 if self.components.is_same_file {
77 10.0 / self.components.same_file_declaration_count as f32
78 } else if self.components.path_import_match_count > 0 {
79 3.0
80 } else if self.components.wildcard_path_import_match_count > 0 {
81 1.0
82 } else if self.components.normalized_import_similarity > 0.0 {
83 self.components.normalized_import_similarity
84 } else if self.components.normalized_wildcard_import_similarity > 0.0 {
85 0.5 * self.components.normalized_wildcard_import_similarity
86 } else {
87 1.0 / self.components.declaration_count as f32
88 }
89 }
90
91 pub fn size(&self, style: DeclarationStyle) -> usize {
92 match &self.declaration {
93 Declaration::File { declaration, .. } => match style {
94 DeclarationStyle::Signature => declaration.signature_range.len(),
95 DeclarationStyle::Declaration => declaration.text.len(),
96 },
97 Declaration::Buffer { declaration, .. } => match style {
98 DeclarationStyle::Signature => declaration.signature_range.len(),
99 DeclarationStyle::Declaration => declaration.item_range.len(),
100 },
101 }
102 }
103
104 pub fn score_density(&self, style: DeclarationStyle) -> f32 {
105 self.score(style) / self.size(style) as f32
106 }
107}
108
109pub fn scored_declarations(
110 options: &EditPredictionScoreOptions,
111 index: &SyntaxIndexState,
112 excerpt: &EditPredictionExcerpt,
113 excerpt_occurrences: &Occurrences,
114 adjacent_occurrences: &Occurrences,
115 imports: &Imports,
116 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
117 cursor_offset: usize,
118 current_buffer: &BufferSnapshot,
119) -> Vec<ScoredDeclaration> {
120 let cursor_point = cursor_offset.to_point(¤t_buffer);
121
122 let mut wildcard_import_occurrences = Vec::new();
123 let mut wildcard_import_paths = Vec::new();
124 for wildcard_import in imports.wildcard_modules.iter() {
125 match wildcard_import {
126 Module::Namespace(namespace) => {
127 wildcard_import_occurrences.push(namespace.occurrences())
128 }
129 Module::SourceExact(path) => wildcard_import_paths.push(path),
130 Module::SourceFuzzy(path) => {
131 wildcard_import_occurrences.push(Occurrences::from_path(&path))
132 }
133 }
134 }
135
136 let mut declarations = identifier_to_references
137 .into_iter()
138 .flat_map(|(identifier, references)| {
139 let mut import_occurrences = Vec::new();
140 let mut import_paths = Vec::new();
141 let mut found_external_identifier: Option<&Identifier> = None;
142
143 if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
144 // only use alias when it's the only import, could be generalized if some language
145 // has overlapping aliases
146 //
147 // TODO: when an aliased declaration is included in the prompt, should include the
148 // aliasing in the prompt.
149 //
150 // TODO: For SourceFuzzy consider having componentwise comparison that pays
151 // attention to ordering.
152 if let [
153 Import::Alias {
154 module,
155 external_identifier,
156 },
157 ] = imports.as_slice()
158 {
159 match module {
160 Module::Namespace(namespace) => {
161 import_occurrences.push(namespace.occurrences())
162 }
163 Module::SourceExact(path) => import_paths.push(path),
164 Module::SourceFuzzy(path) => {
165 import_occurrences.push(Occurrences::from_path(&path))
166 }
167 }
168 found_external_identifier = Some(&external_identifier);
169 } else {
170 for import in imports {
171 match import {
172 Import::Direct { module } => match module {
173 Module::Namespace(namespace) => {
174 import_occurrences.push(namespace.occurrences())
175 }
176 Module::SourceExact(path) => import_paths.push(path),
177 Module::SourceFuzzy(path) => {
178 import_occurrences.push(Occurrences::from_path(&path))
179 }
180 },
181 Import::Alias { .. } => {}
182 }
183 }
184 }
185 }
186
187 let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
188 // TODO: update this to be able to return more declarations? Especially if there is the
189 // ability to quickly filter a large list (based on imports)
190 let declarations = index
191 .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(
192 &identifier_to_lookup,
193 );
194 let declaration_count = declarations.len();
195
196 if declaration_count == 0 {
197 return Vec::new();
198 }
199
200 // TODO: option to filter out other candidates when same file / import match
201 let mut checked_declarations = Vec::new();
202 for (declaration_id, declaration) in declarations {
203 match declaration {
204 Declaration::Buffer {
205 buffer_id,
206 declaration: buffer_declaration,
207 ..
208 } => {
209 if buffer_id == ¤t_buffer.remote_id() {
210 let already_included_in_prompt =
211 range_intersection(&buffer_declaration.item_range, &excerpt.range)
212 .is_some()
213 || excerpt.parent_declarations.iter().any(
214 |(excerpt_parent, _)| excerpt_parent == &declaration_id,
215 );
216 if !options.omit_excerpt_overlaps || !already_included_in_prompt {
217 let declaration_line = buffer_declaration
218 .item_range
219 .start
220 .to_point(current_buffer)
221 .row;
222 let declaration_line_distance = (cursor_point.row as i32
223 - declaration_line as i32)
224 .unsigned_abs();
225 checked_declarations.push(CheckedDeclaration {
226 declaration,
227 same_file_line_distance: Some(declaration_line_distance),
228 path_import_match_count: 0,
229 wildcard_path_import_match_count: 0,
230 });
231 }
232 continue;
233 } else {
234 }
235 }
236 Declaration::File { .. } => {}
237 }
238 let declaration_path = declaration.cached_path();
239 let path_import_match_count = import_paths
240 .iter()
241 .filter(|import_path| {
242 declaration_path_matches_import(&declaration_path, import_path)
243 })
244 .count();
245 let wildcard_path_import_match_count = wildcard_import_paths
246 .iter()
247 .filter(|import_path| {
248 declaration_path_matches_import(&declaration_path, import_path)
249 })
250 .count();
251 checked_declarations.push(CheckedDeclaration {
252 declaration,
253 same_file_line_distance: None,
254 path_import_match_count,
255 wildcard_path_import_match_count,
256 });
257 }
258
259 let mut max_import_similarity = 0.0;
260 let mut max_wildcard_import_similarity = 0.0;
261
262 let mut scored_declarations_for_identifier = checked_declarations
263 .into_iter()
264 .map(|checked_declaration| {
265 let same_file_declaration_count =
266 index.file_declaration_count(checked_declaration.declaration);
267
268 let declaration = score_declaration(
269 &identifier,
270 &references,
271 checked_declaration,
272 same_file_declaration_count,
273 declaration_count,
274 &excerpt_occurrences,
275 &adjacent_occurrences,
276 &import_occurrences,
277 &wildcard_import_occurrences,
278 cursor_point,
279 current_buffer,
280 );
281
282 if declaration.components.import_similarity > max_import_similarity {
283 max_import_similarity = declaration.components.import_similarity;
284 }
285
286 if declaration.components.wildcard_import_similarity
287 > max_wildcard_import_similarity
288 {
289 max_wildcard_import_similarity =
290 declaration.components.wildcard_import_similarity;
291 }
292
293 declaration
294 })
295 .collect::<Vec<_>>();
296
297 if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
298 for declaration in scored_declarations_for_identifier.iter_mut() {
299 if max_import_similarity > 0.0 {
300 declaration.components.max_import_similarity = max_import_similarity;
301 declaration.components.normalized_import_similarity =
302 declaration.components.import_similarity / max_import_similarity;
303 }
304 if max_wildcard_import_similarity > 0.0 {
305 declaration.components.normalized_wildcard_import_similarity =
306 declaration.components.wildcard_import_similarity
307 / max_wildcard_import_similarity;
308 }
309 }
310 }
311
312 scored_declarations_for_identifier
313 })
314 .collect::<Vec<_>>();
315
316 declarations.sort_unstable_by_key(|declaration| {
317 let score_density = declaration
318 .score_density(DeclarationStyle::Declaration)
319 .max(declaration.score_density(DeclarationStyle::Signature));
320 Reverse(OrderedFloat(score_density))
321 });
322
323 declarations
324}
325
/// A candidate declaration after the cheap per-candidate checks in `scored_declarations`, and
/// before full scoring by `score_declaration`.
struct CheckedDeclaration<'a> {
    /// The candidate, as returned by the index lookup.
    declaration: &'a Declaration,
    /// Absolute row distance between the cursor and the start of the declaration's item range.
    /// `Some` only for declarations in the current buffer; `None` marks other files.
    same_file_line_distance: Option<u32>,
    /// Number of exact (non-wildcard) import paths that match the declaration's path.
    path_import_match_count: usize,
    /// Number of exact wildcard-import paths that match the declaration's path.
    wildcard_path_import_match_count: usize,
}
332
333fn declaration_path_matches_import(
334 declaration_path: &CachedDeclarationPath,
335 import_path: &Arc<Path>,
336) -> bool {
337 if import_path.is_absolute() {
338 declaration_path.equals_absolute_path(import_path)
339 } else {
340 declaration_path.ends_with_posix_path(import_path)
341 }
342}
343
/// Returns the overlap of two half-open ranges, or `None` if they share no element.
///
/// Ranges that only touch at an endpoint (e.g. `0..3` and `3..5`) do not intersect, and an
/// empty input range never intersects anything.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let overlap_start = std::cmp::max(a.start.clone(), b.start.clone());
    let overlap_end = std::cmp::min(a.end.clone(), b.end.clone());
    (overlap_start < overlap_end).then_some(overlap_start..overlap_end)
}
353
354fn score_declaration(
355 identifier: &Identifier,
356 references: &[Reference],
357 checked_declaration: CheckedDeclaration,
358 same_file_declaration_count: usize,
359 declaration_count: usize,
360 excerpt_occurrences: &Occurrences,
361 adjacent_occurrences: &Occurrences,
362 import_occurrences: &[Occurrences],
363 wildcard_import_occurrences: &[Occurrences],
364 cursor: Point,
365 current_buffer: &BufferSnapshot,
366) -> ScoredDeclaration {
367 let CheckedDeclaration {
368 declaration,
369 same_file_line_distance,
370 path_import_match_count,
371 wildcard_path_import_match_count,
372 } = checked_declaration;
373
374 let is_referenced_nearby = references
375 .iter()
376 .any(|r| r.region == ReferenceRegion::Nearby);
377 let is_referenced_in_breadcrumb = references
378 .iter()
379 .any(|r| r.region == ReferenceRegion::Breadcrumb);
380 let reference_count = references.len();
381 let reference_line_distance = references
382 .iter()
383 .map(|r| {
384 let reference_line = r.range.start.to_point(current_buffer).row as i32;
385 (cursor.row as i32 - reference_line).unsigned_abs()
386 })
387 .min()
388 .unwrap();
389
390 let is_same_file = same_file_line_distance.is_some();
391 let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
392
393 let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
394 let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
395 let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
396 let excerpt_vs_signature_jaccard =
397 jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
398 let adjacent_vs_item_jaccard =
399 jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
400 let adjacent_vs_signature_jaccard =
401 jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
402
403 let excerpt_vs_item_weighted_overlap =
404 weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
405 let excerpt_vs_signature_weighted_overlap =
406 weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
407 let adjacent_vs_item_weighted_overlap =
408 weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
409 let adjacent_vs_signature_weighted_overlap =
410 weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
411
412 let mut import_similarity = 0f32;
413 let mut wildcard_import_similarity = 0f32;
414 if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
415 let cached_path = declaration.cached_path();
416 let path_occurrences = Occurrences::from_worktree_path(
417 cached_path
418 .worktree_abs_path
419 .file_name()
420 .map(|f| f.to_string_lossy()),
421 &cached_path.rel_path,
422 );
423 import_similarity = import_occurrences
424 .iter()
425 .map(|namespace_occurrences| {
426 OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
427 })
428 .max()
429 .map(|similarity| similarity.into_inner())
430 .unwrap_or_default();
431
432 // TODO: Consider something other than max
433 wildcard_import_similarity = wildcard_import_occurrences
434 .iter()
435 .map(|namespace_occurrences| {
436 OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
437 })
438 .max()
439 .map(|similarity| similarity.into_inner())
440 .unwrap_or_default();
441 }
442
443 // TODO: Consider adding declaration_file_count
444 let score_components = DeclarationScoreComponents {
445 is_same_file,
446 is_referenced_nearby,
447 is_referenced_in_breadcrumb,
448 reference_line_distance,
449 declaration_line_distance,
450 reference_count,
451 same_file_declaration_count,
452 declaration_count,
453 excerpt_vs_item_jaccard,
454 excerpt_vs_signature_jaccard,
455 adjacent_vs_item_jaccard,
456 adjacent_vs_signature_jaccard,
457 excerpt_vs_item_weighted_overlap,
458 excerpt_vs_signature_weighted_overlap,
459 adjacent_vs_item_weighted_overlap,
460 adjacent_vs_signature_weighted_overlap,
461 path_import_match_count,
462 wildcard_path_import_match_count,
463 import_similarity,
464 max_import_similarity: 0.0,
465 normalized_import_similarity: 0.0,
466 wildcard_import_similarity,
467 normalized_wildcard_import_similarity: 0.0,
468 };
469
470 ScoredDeclaration {
471 identifier: identifier.clone(),
472 declaration: declaration.clone(),
473 components: score_components,
474 }
475}
476
#[cfg(test)]
mod test {
    use super::*;

    /// Relative import paths match a declaration path by suffix; absolute import paths must
    /// match it exactly.
    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");

        let cases = [
            ("maths.ts", true),
            ("project/src/maths.ts", true),
            ("user/project/src/maths.ts", true),
            ("/home/user/project/src/maths.ts", true),
            ("other.ts", false),
            ("/home/user/project/src/other.ts", false),
        ];

        for &(raw_import_path, should_match) in &cases {
            let import_path: Arc<Path> = Path::new(raw_import_path).into();
            assert_eq!(
                declaration_path_matches_import(&declaration_path, &import_path),
                should_match
            );
        }
    }
}