1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
2use collections::HashMap;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use project::ProjectEntryId;
6use serde::Serialize;
7use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
8use strum::EnumIter;
9use text::{Point, ToPoint};
10use util::RangeExt as _;
11
12use crate::{
13 CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
14 imports::{Import, Imports, Module},
15 reference::{Reference, ReferenceRegion},
16 syntax_index::SyntaxIndexState,
17 text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
18};
19
/// Const-generic argument supplied to `SyntaxIndexState::declarations_for_identifier` when
/// candidates are looked up for a referenced identifier.
// NOTE(review): presumably an upper bound on the number of declarations returned per
// identifier — confirm against `declarations_for_identifier`.
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
21
/// Options controlling which declarations `scored_declarations` emits.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct EditPredictionScoreOptions {
    /// When `true`, declarations in the current buffer are skipped if their item range
    /// intersects the excerpt range or if they are one of the excerpt's parent declarations.
    pub omit_excerpt_overlaps: bool,
}
26
/// A candidate declaration for an identifier referenced around the cursor, together with the
/// individual signals used to rank it.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// identifier used by the local reference (for an aliased import this is the local name,
    /// not the external identifier that was looked up in the index)
    pub identifier: Identifier,
    /// The candidate declaration being scored.
    pub declaration: Declaration,
    /// Raw scoring signals; combined by `score` and `retrieval_score`.
    pub components: DeclarationScoreComponents,
}
34
/// Selects how much of a declaration is considered when scoring and sizing it.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the signature: scored with the excerpt-vs-signature weighted overlap and sized by
    /// the length of the signature range.
    Signature,
    /// The whole item: scored with the excerpt-vs-item weighted overlap (with a 2x weight) and
    /// sized by the length of the item's text or item range.
    Declaration,
}
40
/// Serializable bundle of per-declaration scores.
// NOTE(review): this struct is not constructed anywhere in this file. The fields presumably
// hold `ScoredDeclaration::score` for each `DeclarationStyle` plus
// `ScoredDeclaration::retrieval_score` — confirm at the construction site.
#[derive(Clone, Debug, Serialize, Default)]
pub struct DeclarationScores {
    pub signature: f32,
    pub declaration: f32,
    pub retrieval: f32,
}
47
48impl ScoredDeclaration {
49 /// Returns the score for this declaration with the specified style.
50 pub fn score(&self, style: DeclarationStyle) -> f32 {
51 // TODO: handle truncation
52
53 // Score related to how likely this is the correct declaration, range 0 to 1
54 let retrieval = self.retrieval_score();
55
56 // Score related to the distance between the reference and cursor, range 0 to 1
57 let distance_score = if self.components.is_referenced_nearby {
58 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
59 } else {
60 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
61 0.5
62 };
63
64 // For now instead of linear combination, the scores are just multiplied together.
65 let combined_score = 10.0 * retrieval * distance_score;
66
67 match style {
68 DeclarationStyle::Signature => {
69 combined_score * self.components.excerpt_vs_signature_weighted_overlap
70 }
71 DeclarationStyle::Declaration => {
72 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
73 }
74 }
75 }
76
77 pub fn retrieval_score(&self) -> f32 {
78 let mut score = if self.components.is_same_file {
79 10.0 / self.components.same_file_declaration_count as f32
80 } else if self.components.path_import_match_count > 0 {
81 3.0
82 } else if self.components.wildcard_path_import_match_count > 0 {
83 1.0
84 } else if self.components.normalized_import_similarity > 0.0 {
85 self.components.normalized_import_similarity
86 } else if self.components.normalized_wildcard_import_similarity > 0.0 {
87 0.5 * self.components.normalized_wildcard_import_similarity
88 } else {
89 1.0 / self.components.declaration_count as f32
90 };
91 score *= 1. + self.components.included_by_others as f32 / 2.;
92 score *= 1. + self.components.includes_others as f32 / 4.;
93 score
94 }
95
96 pub fn size(&self, style: DeclarationStyle) -> usize {
97 match &self.declaration {
98 Declaration::File { declaration, .. } => match style {
99 DeclarationStyle::Signature => declaration.signature_range.len(),
100 DeclarationStyle::Declaration => declaration.text.len(),
101 },
102 Declaration::Buffer { declaration, .. } => match style {
103 DeclarationStyle::Signature => declaration.signature_range.len(),
104 DeclarationStyle::Declaration => declaration.item_range.len(),
105 },
106 }
107 }
108
109 pub fn score_density(&self, style: DeclarationStyle) -> f32 {
110 self.score(style) / self.size(style) as f32
111 }
112}
113
114pub fn scored_declarations(
115 options: &EditPredictionScoreOptions,
116 index: &SyntaxIndexState,
117 excerpt: &EditPredictionExcerpt,
118 excerpt_occurrences: &Occurrences,
119 adjacent_occurrences: &Occurrences,
120 imports: &Imports,
121 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
122 cursor_offset: usize,
123 current_buffer: &BufferSnapshot,
124) -> Vec<ScoredDeclaration> {
125 let cursor_point = cursor_offset.to_point(¤t_buffer);
126
127 let mut wildcard_import_occurrences = Vec::new();
128 let mut wildcard_import_paths = Vec::new();
129 for wildcard_import in imports.wildcard_modules.iter() {
130 match wildcard_import {
131 Module::Namespace(namespace) => {
132 wildcard_import_occurrences.push(namespace.occurrences())
133 }
134 Module::SourceExact(path) => wildcard_import_paths.push(path),
135 Module::SourceFuzzy(path) => {
136 wildcard_import_occurrences.push(Occurrences::from_path(&path))
137 }
138 }
139 }
140
141 let mut scored_declarations = Vec::new();
142 let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
143 HashMap::default();
144 for (identifier, references) in identifier_to_references {
145 let mut import_occurrences = Vec::new();
146 let mut import_paths = Vec::new();
147 let mut found_external_identifier: Option<&Identifier> = None;
148
149 if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
150 // only use alias when it's the only import, could be generalized if some language
151 // has overlapping aliases
152 //
153 // TODO: when an aliased declaration is included in the prompt, should include the
154 // aliasing in the prompt.
155 //
156 // TODO: For SourceFuzzy consider having componentwise comparison that pays
157 // attention to ordering.
158 if let [
159 Import::Alias {
160 module,
161 external_identifier,
162 },
163 ] = imports.as_slice()
164 {
165 match module {
166 Module::Namespace(namespace) => {
167 import_occurrences.push(namespace.occurrences())
168 }
169 Module::SourceExact(path) => import_paths.push(path),
170 Module::SourceFuzzy(path) => {
171 import_occurrences.push(Occurrences::from_path(&path))
172 }
173 }
174 found_external_identifier = Some(&external_identifier);
175 } else {
176 for import in imports {
177 match import {
178 Import::Direct { module } => match module {
179 Module::Namespace(namespace) => {
180 import_occurrences.push(namespace.occurrences())
181 }
182 Module::SourceExact(path) => import_paths.push(path),
183 Module::SourceFuzzy(path) => {
184 import_occurrences.push(Occurrences::from_path(&path))
185 }
186 },
187 Import::Alias { .. } => {}
188 }
189 }
190 }
191 }
192
193 let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
194 // TODO: update this to be able to return more declarations? Especially if there is the
195 // ability to quickly filter a large list (based on imports)
196 let identifier_declarations = index
197 .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
198 let declaration_count = identifier_declarations.len();
199
200 if declaration_count == 0 {
201 continue;
202 }
203
204 // TODO: option to filter out other candidates when same file / import match
205 let mut checked_declarations = Vec::with_capacity(declaration_count);
206 for (declaration_id, declaration) in identifier_declarations {
207 match declaration {
208 Declaration::Buffer {
209 buffer_id,
210 declaration: buffer_declaration,
211 ..
212 } => {
213 if buffer_id == ¤t_buffer.remote_id() {
214 let already_included_in_prompt =
215 range_intersection(&buffer_declaration.item_range, &excerpt.range)
216 .is_some()
217 || excerpt
218 .parent_declarations
219 .iter()
220 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
221 if !options.omit_excerpt_overlaps || !already_included_in_prompt {
222 let declaration_line = buffer_declaration
223 .item_range
224 .start
225 .to_point(current_buffer)
226 .row;
227 let declaration_line_distance =
228 (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
229 checked_declarations.push(CheckedDeclaration {
230 declaration,
231 same_file_line_distance: Some(declaration_line_distance),
232 path_import_match_count: 0,
233 wildcard_path_import_match_count: 0,
234 });
235 }
236 continue;
237 } else {
238 }
239 }
240 Declaration::File { .. } => {}
241 }
242 let declaration_path = declaration.cached_path();
243 let path_import_match_count = import_paths
244 .iter()
245 .filter(|import_path| {
246 declaration_path_matches_import(&declaration_path, import_path)
247 })
248 .count();
249 let wildcard_path_import_match_count = wildcard_import_paths
250 .iter()
251 .filter(|import_path| {
252 declaration_path_matches_import(&declaration_path, import_path)
253 })
254 .count();
255 checked_declarations.push(CheckedDeclaration {
256 declaration,
257 same_file_line_distance: None,
258 path_import_match_count,
259 wildcard_path_import_match_count,
260 });
261 }
262
263 let mut max_import_similarity = 0.0;
264 let mut max_wildcard_import_similarity = 0.0;
265
266 let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
267 for checked_declaration in checked_declarations {
268 let same_file_declaration_count =
269 index.file_declaration_count(checked_declaration.declaration);
270
271 let declaration = score_declaration(
272 &identifier,
273 &references,
274 checked_declaration,
275 same_file_declaration_count,
276 declaration_count,
277 &excerpt_occurrences,
278 &adjacent_occurrences,
279 &import_occurrences,
280 &wildcard_import_occurrences,
281 cursor_point,
282 current_buffer,
283 );
284
285 if declaration.components.import_similarity > max_import_similarity {
286 max_import_similarity = declaration.components.import_similarity;
287 }
288
289 if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
290 max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
291 }
292
293 project_entry_id_to_outline_ranges
294 .entry(declaration.declaration.project_entry_id())
295 .or_default()
296 .push(declaration.declaration.item_range());
297 scored_declarations_for_identifier.push(declaration);
298 }
299
300 if max_import_similarity > 0.0 || max_wildcard_import_similarity > 0.0 {
301 for declaration in scored_declarations_for_identifier.iter_mut() {
302 if max_import_similarity > 0.0 {
303 declaration.components.max_import_similarity = max_import_similarity;
304 declaration.components.normalized_import_similarity =
305 declaration.components.import_similarity / max_import_similarity;
306 }
307 if max_wildcard_import_similarity > 0.0 {
308 declaration.components.normalized_wildcard_import_similarity =
309 declaration.components.wildcard_import_similarity
310 / max_wildcard_import_similarity;
311 }
312 }
313 }
314
315 scored_declarations.extend(scored_declarations_for_identifier);
316 }
317
318 // TODO: Inform this via import / retrieval scores of outline items
319 // TODO: Consider using a sweepline
320 for scored_declaration in scored_declarations.iter_mut() {
321 let project_entry_id = scored_declaration.declaration.project_entry_id();
322 let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
323 continue;
324 };
325 for range in ranges {
326 if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
327 scored_declaration.components.included_by_others += 1
328 } else if scored_declaration
329 .declaration
330 .item_range()
331 .contains_inclusive(range)
332 {
333 scored_declaration.components.includes_others += 1
334 }
335 }
336 }
337
338 scored_declarations.sort_unstable_by_key(|declaration| {
339 Reverse(OrderedFloat(
340 declaration.score(DeclarationStyle::Declaration),
341 ))
342 });
343
344 scored_declarations
345}
346
/// A candidate declaration annotated with the location and import-path checks performed in
/// `scored_declarations`, ready to be passed to `score_declaration`.
struct CheckedDeclaration<'a> {
    /// The candidate declaration, borrowed from the index lookup result.
    declaration: &'a Declaration,
    /// `Some(row distance between the cursor and the declaration's start)` when the
    /// declaration lives in the current buffer, `None` otherwise.
    same_file_line_distance: Option<u32>,
    /// Number of the identifier's exact import paths that match the declaration's path
    /// (always 0 for current-buffer declarations).
    path_import_match_count: usize,
    /// Number of wildcard exact import paths that match the declaration's path
    /// (always 0 for current-buffer declarations).
    wildcard_path_import_match_count: usize,
}
353
354fn declaration_path_matches_import(
355 declaration_path: &CachedDeclarationPath,
356 import_path: &Arc<Path>,
357) -> bool {
358 if import_path.is_absolute() {
359 declaration_path.equals_absolute_path(import_path)
360 } else {
361 declaration_path.ends_with_posix_path(import_path)
362 }
363}
364
/// Returns the overlap of the half-open ranges `a` and `b`, or `None` when they share no
/// elements (including when they merely touch at an endpoint).
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    // Compare by reference so the bounds are only cloned when an intersection exists.
    let lower = Ord::max(&a.start, &b.start);
    let upper = Ord::min(&a.end, &b.end);
    (lower < upper).then(|| lower.clone()..upper.clone())
}
374
375fn score_declaration(
376 identifier: &Identifier,
377 references: &[Reference],
378 checked_declaration: CheckedDeclaration,
379 same_file_declaration_count: usize,
380 declaration_count: usize,
381 excerpt_occurrences: &Occurrences,
382 adjacent_occurrences: &Occurrences,
383 import_occurrences: &[Occurrences],
384 wildcard_import_occurrences: &[Occurrences],
385 cursor: Point,
386 current_buffer: &BufferSnapshot,
387) -> ScoredDeclaration {
388 let CheckedDeclaration {
389 declaration,
390 same_file_line_distance,
391 path_import_match_count,
392 wildcard_path_import_match_count,
393 } = checked_declaration;
394
395 let is_referenced_nearby = references
396 .iter()
397 .any(|r| r.region == ReferenceRegion::Nearby);
398 let is_referenced_in_breadcrumb = references
399 .iter()
400 .any(|r| r.region == ReferenceRegion::Breadcrumb);
401 let reference_count = references.len();
402 let reference_line_distance = references
403 .iter()
404 .map(|r| {
405 let reference_line = r.range.start.to_point(current_buffer).row as i32;
406 (cursor.row as i32 - reference_line).unsigned_abs()
407 })
408 .min()
409 .unwrap();
410
411 let is_same_file = same_file_line_distance.is_some();
412 let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
413
414 let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
415 let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
416 let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
417 let excerpt_vs_signature_jaccard =
418 jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
419 let adjacent_vs_item_jaccard =
420 jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
421 let adjacent_vs_signature_jaccard =
422 jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
423
424 let excerpt_vs_item_weighted_overlap =
425 weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
426 let excerpt_vs_signature_weighted_overlap =
427 weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
428 let adjacent_vs_item_weighted_overlap =
429 weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
430 let adjacent_vs_signature_weighted_overlap =
431 weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
432
433 let mut import_similarity = 0f32;
434 let mut wildcard_import_similarity = 0f32;
435 if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
436 let cached_path = declaration.cached_path();
437 let path_occurrences = Occurrences::from_worktree_path(
438 cached_path
439 .worktree_abs_path
440 .file_name()
441 .map(|f| f.to_string_lossy()),
442 &cached_path.rel_path,
443 );
444 import_similarity = import_occurrences
445 .iter()
446 .map(|namespace_occurrences| {
447 OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
448 })
449 .max()
450 .map(|similarity| similarity.into_inner())
451 .unwrap_or_default();
452
453 // TODO: Consider something other than max
454 wildcard_import_similarity = wildcard_import_occurrences
455 .iter()
456 .map(|namespace_occurrences| {
457 OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
458 })
459 .max()
460 .map(|similarity| similarity.into_inner())
461 .unwrap_or_default();
462 }
463
464 // TODO: Consider adding declaration_file_count
465 let score_components = DeclarationScoreComponents {
466 is_same_file,
467 is_referenced_nearby,
468 is_referenced_in_breadcrumb,
469 reference_line_distance,
470 declaration_line_distance,
471 reference_count,
472 same_file_declaration_count,
473 declaration_count,
474 excerpt_vs_item_jaccard,
475 excerpt_vs_signature_jaccard,
476 adjacent_vs_item_jaccard,
477 adjacent_vs_signature_jaccard,
478 excerpt_vs_item_weighted_overlap,
479 excerpt_vs_signature_weighted_overlap,
480 adjacent_vs_item_weighted_overlap,
481 adjacent_vs_signature_weighted_overlap,
482 path_import_match_count,
483 wildcard_path_import_match_count,
484 import_similarity,
485 max_import_similarity: 0.0,
486 normalized_import_similarity: 0.0,
487 wildcard_import_similarity,
488 normalized_wildcard_import_similarity: 0.0,
489 included_by_others: 0,
490 includes_others: 0,
491 };
492
493 ScoredDeclaration {
494 identifier: identifier.clone(),
495 declaration: declaration.clone(),
496 components: score_components,
497 }
498}
499
#[cfg(test)]
mod test {
    use super::*;

    /// Checks `declaration_path_matches_import` against relative and absolute import paths
    /// for a declaration at `src/maths.ts` in the worktree `/home/user/project`.
    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
        let matches = |import: &str| {
            declaration_path_matches_import(&declaration_path, &Path::new(import).into())
        };

        let expected_matches = [
            "maths.ts",
            "project/src/maths.ts",
            "user/project/src/maths.ts",
            "/home/user/project/src/maths.ts",
        ];
        for import in expected_matches {
            assert!(matches(import), "expected `{import}` to match");
        }

        let expected_mismatches = ["other.ts", "/home/user/project/src/other.ts"];
        for import in expected_mismatches {
            assert!(!matches(import), "expected `{import}` not to match");
        }
    }
}