1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
2use collections::HashMap;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use project::ProjectEntryId;
6use serde::Serialize;
7use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
8use strum::EnumIter;
9use text::{Point, ToPoint};
10use util::RangeExt as _;
11
12use crate::{
13 declaration::{CachedDeclarationPath, Declaration, Identifier},
14 excerpt::EditPredictionExcerpt,
15 imports::{Import, Imports, Module},
16 reference::{Reference, ReferenceRegion},
17 syntax_index::SyntaxIndexState,
18 text_similarity::{
19 IdentifierParts, OccurrenceSource, Occurrences, Similarity as _, WeightedSimilarity as _,
20 },
21};
22
/// Maximum number of declarations requested from the index for a single identifier.
///
/// Passed as the const generic argument to `declarations_for_identifier` in
/// [`scored_declarations`].
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
24
/// Options that control which candidates [`scored_declarations`] keeps.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EditPredictionScoreOptions {
    /// When `true`, declarations in the current buffer whose item range intersects the excerpt
    /// range, or that are among the excerpt's parent declarations, are left out of the results.
    pub omit_excerpt_overlaps: bool,
}
29
30#[derive(Clone, Debug)]
31pub struct ScoredDeclaration {
32 /// identifier used by the local reference
33 pub identifier: Identifier,
34 pub declaration: Declaration,
35 pub components: DeclarationScoreComponents,
36}
37
38#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
39pub enum DeclarationStyle {
40 Signature,
41 Declaration,
42}
43
44#[derive(Clone, Debug, Serialize, Default)]
45pub struct DeclarationScores {
46 pub signature: f32,
47 pub declaration: f32,
48 pub retrieval: f32,
49}
50
51impl ScoredDeclaration {
52 /// Returns the score for this declaration with the specified style.
53 pub fn score(&self, style: DeclarationStyle) -> f32 {
54 // TODO: handle truncation
55
56 // Score related to how likely this is the correct declaration, range 0 to 1
57 let retrieval = self.retrieval_score();
58
59 // Score related to the distance between the reference and cursor, range 0 to 1
60 let distance_score = if self.components.is_referenced_nearby {
61 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
62 } else {
63 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
64 0.5
65 };
66
67 // For now instead of linear combination, the scores are just multiplied together.
68 let combined_score = 10.0 * retrieval * distance_score;
69
70 match style {
71 DeclarationStyle::Signature => {
72 combined_score * self.components.excerpt_vs_signature_weighted_overlap
73 }
74 DeclarationStyle::Declaration => {
75 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
76 }
77 }
78 }
79
80 pub fn retrieval_score(&self) -> f32 {
81 let mut score = if self.components.is_same_file {
82 10.0 / self.components.same_file_declaration_count as f32
83 } else if self.components.path_import_match_count > 0 {
84 3.0
85 } else if self.components.wildcard_path_import_match_count > 0 {
86 1.0
87 } else if self.components.normalized_import_similarity > 0.0 {
88 self.components.normalized_import_similarity
89 } else if self.components.normalized_wildcard_import_similarity > 0.0 {
90 0.5 * self.components.normalized_wildcard_import_similarity
91 } else {
92 1.0 / self.components.declaration_count as f32
93 };
94 score *= 1. + self.components.included_by_others as f32 / 2.;
95 score *= 1. + self.components.includes_others as f32 / 4.;
96 score
97 }
98
99 pub fn size(&self, style: DeclarationStyle) -> usize {
100 match &self.declaration {
101 Declaration::File { declaration, .. } => match style {
102 DeclarationStyle::Signature => declaration.signature_range.len(),
103 DeclarationStyle::Declaration => declaration.text.len(),
104 },
105 Declaration::Buffer { declaration, .. } => match style {
106 DeclarationStyle::Signature => declaration.signature_range.len(),
107 DeclarationStyle::Declaration => declaration.item_range.len(),
108 },
109 }
110 }
111
112 pub fn score_density(&self, style: DeclarationStyle) -> f32 {
113 self.score(style) / self.size(style) as f32
114 }
115}
116
117pub fn scored_declarations(
118 options: &EditPredictionScoreOptions,
119 index: &SyntaxIndexState,
120 excerpt: &EditPredictionExcerpt,
121 excerpt_occurrences: &Occurrences<IdentifierParts>,
122 adjacent_occurrences: &Occurrences<IdentifierParts>,
123 imports: &Imports,
124 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
125 cursor_offset: usize,
126 current_buffer: &BufferSnapshot,
127) -> Vec<ScoredDeclaration> {
128 let cursor_point = cursor_offset.to_point(¤t_buffer);
129
130 let mut wildcard_import_occurrences = Vec::new();
131 let mut wildcard_import_paths = Vec::new();
132 for wildcard_import in imports.wildcard_modules.iter() {
133 match wildcard_import {
134 Module::Namespace(namespace) => {
135 wildcard_import_occurrences.push(namespace.occurrences())
136 }
137 Module::SourceExact(path) => wildcard_import_paths.push(path),
138 Module::SourceFuzzy(path) => wildcard_import_occurrences.push(Occurrences::new(
139 IdentifierParts::occurrences_in_path(&path),
140 )),
141 }
142 }
143
144 let mut scored_declarations = Vec::new();
145 let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
146 HashMap::default();
147 for (identifier, references) in identifier_to_references {
148 let mut import_occurrences = Vec::new();
149 let mut import_paths = Vec::new();
150 let mut found_external_identifier: Option<&Identifier> = None;
151
152 if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
153 // only use alias when it's the only import, could be generalized if some language
154 // has overlapping aliases
155 //
156 // TODO: when an aliased declaration is included in the prompt, should include the
157 // aliasing in the prompt.
158 //
159 // TODO: For SourceFuzzy consider having componentwise comparison that pays
160 // attention to ordering.
161 if let [
162 Import::Alias {
163 module,
164 external_identifier,
165 },
166 ] = imports.as_slice()
167 {
168 match module {
169 Module::Namespace(namespace) => {
170 import_occurrences.push(namespace.occurrences())
171 }
172 Module::SourceExact(path) => import_paths.push(path),
173 Module::SourceFuzzy(path) => import_occurrences.push(Occurrences::new(
174 IdentifierParts::occurrences_in_path(&path),
175 )),
176 }
177 found_external_identifier = Some(&external_identifier);
178 } else {
179 for import in imports {
180 match import {
181 Import::Direct { module } => match module {
182 Module::Namespace(namespace) => {
183 import_occurrences.push(namespace.occurrences())
184 }
185 Module::SourceExact(path) => import_paths.push(path),
186 Module::SourceFuzzy(path) => import_occurrences.push(Occurrences::new(
187 IdentifierParts::occurrences_in_path(&path),
188 )),
189 },
190 Import::Alias { .. } => {}
191 }
192 }
193 }
194 }
195
196 let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
197 // TODO: update this to be able to return more declarations? Especially if there is the
198 // ability to quickly filter a large list (based on imports)
199 let identifier_declarations = index
200 .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
201 let declaration_count = identifier_declarations.len();
202
203 if declaration_count == 0 {
204 continue;
205 }
206
207 // TODO: option to filter out other candidates when same file / import match
208 let mut checked_declarations = Vec::with_capacity(declaration_count);
209 for (declaration_id, declaration) in identifier_declarations {
210 match declaration {
211 Declaration::Buffer {
212 buffer_id,
213 declaration: buffer_declaration,
214 ..
215 } => {
216 if buffer_id == ¤t_buffer.remote_id() {
217 let already_included_in_prompt =
218 range_intersection(&buffer_declaration.item_range, &excerpt.range)
219 .is_some()
220 || excerpt
221 .parent_declarations
222 .iter()
223 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
224 if !options.omit_excerpt_overlaps || !already_included_in_prompt {
225 let declaration_line = buffer_declaration
226 .item_range
227 .start
228 .to_point(current_buffer)
229 .row;
230 let declaration_line_distance =
231 (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
232 checked_declarations.push(CheckedDeclaration {
233 declaration,
234 same_file_line_distance: Some(declaration_line_distance),
235 path_import_match_count: 0,
236 wildcard_path_import_match_count: 0,
237 });
238 }
239 continue;
240 } else {
241 }
242 }
243 Declaration::File { .. } => {}
244 }
245 let declaration_path = declaration.cached_path();
246 let path_import_match_count = import_paths
247 .iter()
248 .filter(|import_path| {
249 declaration_path_matches_import(&declaration_path, import_path)
250 })
251 .count();
252 let wildcard_path_import_match_count = wildcard_import_paths
253 .iter()
254 .filter(|import_path| {
255 declaration_path_matches_import(&declaration_path, import_path)
256 })
257 .count();
258 checked_declarations.push(CheckedDeclaration {
259 declaration,
260 same_file_line_distance: None,
261 path_import_match_count,
262 wildcard_path_import_match_count,
263 });
264 }
265
266 let mut max_import_similarity = 0.0;
267 let mut max_wildcard_import_similarity = 0.0;
268
269 let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
270 for checked_declaration in checked_declarations {
271 let same_file_declaration_count =
272 index.file_declaration_count(checked_declaration.declaration);
273
274 let declaration = score_declaration(
275 &identifier,
276 &references,
277 checked_declaration,
278 same_file_declaration_count,
279 declaration_count,
280 &excerpt_occurrences,
281 &adjacent_occurrences,
282 &import_occurrences,
283 &wildcard_import_occurrences,
284 cursor_point,
285 current_buffer,
286 );
287
288 if declaration.components.import_similarity > max_import_similarity {
289 max_import_similarity = declaration.components.import_similarity;
290 }
291
292 if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
293 max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
294 }
295
296 project_entry_id_to_outline_ranges
297 .entry(declaration.declaration.project_entry_id())
298 .or_default()
299 .push(declaration.declaration.item_range());
300 scored_declarations_for_identifier.push(declaration);
301 }
302
303 if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
304 for declaration in scored_declarations_for_identifier.iter_mut() {
305 if max_import_similarity > 0.0 {
306 declaration.components.max_import_similarity = max_import_similarity;
307 declaration.components.normalized_import_similarity =
308 declaration.components.import_similarity / max_import_similarity;
309 }
310 if max_wildcard_import_similarity > 0.0 {
311 declaration.components.normalized_wildcard_import_similarity =
312 declaration.components.wildcard_import_similarity
313 / max_wildcard_import_similarity;
314 }
315 }
316 }
317
318 scored_declarations.extend(scored_declarations_for_identifier);
319 }
320
321 // TODO: Inform this via import / retrieval scores of outline items
322 // TODO: Consider using a sweepline
323 for scored_declaration in scored_declarations.iter_mut() {
324 let project_entry_id = scored_declaration.declaration.project_entry_id();
325 let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
326 continue;
327 };
328 for range in ranges {
329 if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
330 scored_declaration.components.included_by_others += 1
331 } else if scored_declaration
332 .declaration
333 .item_range()
334 .contains_inclusive(range)
335 {
336 scored_declaration.components.includes_others += 1
337 }
338 }
339 }
340
341 scored_declarations.sort_unstable_by_key(|declaration| {
342 Reverse(OrderedFloat(
343 declaration.score(DeclarationStyle::Declaration),
344 ))
345 });
346
347 scored_declarations
348}
349
/// A candidate declaration together with the facts gathered about it before scoring.
struct CheckedDeclaration<'a> {
    /// The candidate declaration, borrowed from the index lookup.
    declaration: &'a Declaration,
    /// Line distance between the cursor and the declaration's start when the declaration
    /// lives in the current buffer; `None` otherwise.
    same_file_line_distance: Option<u32>,
    /// Number of the identifier's exact import paths that match the declaration's path.
    path_import_match_count: usize,
    /// Number of wildcard exact import paths that match the declaration's path.
    wildcard_path_import_match_count: usize,
}
356
357fn declaration_path_matches_import(
358 declaration_path: &CachedDeclarationPath,
359 import_path: &Arc<Path>,
360) -> bool {
361 if import_path.is_absolute() {
362 declaration_path.equals_absolute_path(import_path)
363 } else {
364 declaration_path.ends_with_posix_path(import_path)
365 }
366}
367
/// Returns the overlap of the half-open ranges `a` and `b`, or `None` when they share no
/// element (ranges that merely touch do not overlap).
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    // Compare by reference and clone only the two bounds that end up in the result.
    let lower = Ord::max(&a.start, &b.start);
    let upper = Ord::min(&a.end, &b.end);
    (lower < upper).then(|| lower.clone()..upper.clone())
}
377
378fn score_declaration(
379 identifier: &Identifier,
380 references: &[Reference],
381 checked_declaration: CheckedDeclaration,
382 same_file_declaration_count: usize,
383 declaration_count: usize,
384 excerpt_occurrences: &Occurrences<IdentifierParts>,
385 adjacent_occurrences: &Occurrences<IdentifierParts>,
386 import_occurrences: &[Occurrences<IdentifierParts>],
387 wildcard_import_occurrences: &[Occurrences<IdentifierParts>],
388 cursor: Point,
389 current_buffer: &BufferSnapshot,
390) -> ScoredDeclaration {
391 let CheckedDeclaration {
392 declaration,
393 same_file_line_distance,
394 path_import_match_count,
395 wildcard_path_import_match_count,
396 } = checked_declaration;
397
398 let is_referenced_nearby = references
399 .iter()
400 .any(|r| r.region == ReferenceRegion::Nearby);
401 let is_referenced_in_breadcrumb = references
402 .iter()
403 .any(|r| r.region == ReferenceRegion::Breadcrumb);
404 let reference_count = references.len();
405 let reference_line_distance = references
406 .iter()
407 .map(|r| {
408 let reference_line = r.range.start.to_point(current_buffer).row as i32;
409 (cursor.row as i32 - reference_line).unsigned_abs()
410 })
411 .min()
412 .unwrap();
413
414 let is_same_file = same_file_line_distance.is_some();
415 let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
416
417 let item_source_occurrences = Occurrences::new(IdentifierParts::occurrences_in_str(
418 &declaration.item_text().0,
419 ));
420 let item_signature_occurrences = Occurrences::new(IdentifierParts::occurrences_in_str(
421 &declaration.signature_text().0,
422 ));
423 let excerpt_vs_item_jaccard = excerpt_occurrences.jaccard_similarity(&item_source_occurrences);
424 let excerpt_vs_signature_jaccard =
425 excerpt_occurrences.jaccard_similarity(&item_signature_occurrences);
426 let adjacent_vs_item_jaccard =
427 adjacent_occurrences.jaccard_similarity(&item_source_occurrences);
428 let adjacent_vs_signature_jaccard =
429 adjacent_occurrences.jaccard_similarity(&item_signature_occurrences);
430
431 let excerpt_vs_item_weighted_overlap =
432 excerpt_occurrences.weighted_overlap_coefficient(&item_source_occurrences);
433 let excerpt_vs_signature_weighted_overlap =
434 excerpt_occurrences.weighted_overlap_coefficient(&item_signature_occurrences);
435 let adjacent_vs_item_weighted_overlap =
436 adjacent_occurrences.weighted_overlap_coefficient(&item_source_occurrences);
437 let adjacent_vs_signature_weighted_overlap =
438 adjacent_occurrences.weighted_overlap_coefficient(&item_signature_occurrences);
439
440 let mut import_similarity = 0f32;
441 let mut wildcard_import_similarity = 0f32;
442 if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
443 let cached_path = declaration.cached_path();
444 let path_occurrences = Occurrences::new(IdentifierParts::occurrences_in_worktree_path(
445 cached_path
446 .worktree_abs_path
447 .file_name()
448 .map(|f| f.to_string_lossy()),
449 &cached_path.rel_path,
450 ));
451 import_similarity = import_occurrences
452 .iter()
453 .map(|namespace_occurrences| {
454 OrderedFloat(namespace_occurrences.jaccard_similarity(&path_occurrences))
455 })
456 .max()
457 .map(|similarity| similarity.into_inner())
458 .unwrap_or_default();
459
460 // TODO: Consider something other than max
461 wildcard_import_similarity = wildcard_import_occurrences
462 .iter()
463 .map(|namespace_occurrences| {
464 OrderedFloat(namespace_occurrences.jaccard_similarity(&path_occurrences))
465 })
466 .max()
467 .map(|similarity| similarity.into_inner())
468 .unwrap_or_default();
469 }
470
471 // TODO: Consider adding declaration_file_count
472 let score_components = DeclarationScoreComponents {
473 is_same_file,
474 is_referenced_nearby,
475 is_referenced_in_breadcrumb,
476 reference_line_distance,
477 declaration_line_distance,
478 reference_count,
479 same_file_declaration_count,
480 declaration_count,
481 excerpt_vs_item_jaccard,
482 excerpt_vs_signature_jaccard,
483 adjacent_vs_item_jaccard,
484 adjacent_vs_signature_jaccard,
485 excerpt_vs_item_weighted_overlap,
486 excerpt_vs_signature_weighted_overlap,
487 adjacent_vs_item_weighted_overlap,
488 adjacent_vs_signature_weighted_overlap,
489 path_import_match_count,
490 wildcard_path_import_match_count,
491 import_similarity,
492 max_import_similarity: 0.0,
493 normalized_import_similarity: 0.0,
494 wildcard_import_similarity,
495 normalized_wildcard_import_similarity: 0.0,
496 included_by_others: 0,
497 includes_others: 0,
498 };
499
500 ScoredDeclaration {
501 identifier: identifier.clone(),
502 declaration: declaration.clone(),
503 components: score_components,
504 }
505}
506
#[cfg(test)]
mod test {
    use super::*;

    /// Checks import path matching against a declaration at
    /// `/home/user/project/src/maths.ts`.
    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
        let matches = |import: &str| {
            declaration_path_matches_import(&declaration_path, &Path::new(import).into())
        };

        // Relative imports match when they are a trailing part of the declaration's path.
        assert!(matches("maths.ts"));
        assert!(matches("project/src/maths.ts"));
        assert!(matches("user/project/src/maths.ts"));

        // Absolute imports must name the declaration's file.
        assert!(matches("/home/user/project/src/maths.ts"));

        // Different file names never match, relative or absolute.
        assert!(!matches("other.ts"));
        assert!(!matches("/home/user/project/src/other.ts"));
    }
}